diff options
60 files changed, 2708 insertions, 628 deletions
@@ -44,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.63.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.63.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.64.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.64.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -56,18 +56,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty-shaded</artifactId> - <version>1.63.0</version> + <version>1.64.0</version> <scope>runtime</scope> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-protobuf</artifactId> - <version>1.63.0</version> + <version>1.64.0</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-stub</artifactId> - <version>1.63.0</version> + <version>1.64.0</version> </dependency> <dependency> <!-- necessary for Java 9+ --> <groupId>org.apache.tomcat</groupId> @@ -79,18 +79,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.63.0' -implementation 'io.grpc:grpc-protobuf:1.63.0' -implementation 'io.grpc:grpc-stub:1.63.0' +runtimeOnly 'io.grpc:grpc-netty-shaded:1.64.0' +implementation 'io.grpc:grpc-protobuf:1.64.0' +implementation 'io.grpc:grpc-stub:1.64.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.63.0' -implementation 'io.grpc:grpc-protobuf-lite:1.63.0' -implementation 'io.grpc:grpc-stub:1.63.0' +implementation 'io.grpc:grpc-okhttp:1.64.0' +implementation 'io.grpc:grpc-protobuf-lite:1.64.0' +implementation 'io.grpc:grpc-stub:1.64.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` @@ -99,7 +99,7 @@ For [Bazel](https://bazel.build), you can either (with the GAVs from above), or use `@io_grpc_grpc_java//api` et al (see below). [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.63.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.64.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -131,7 +131,7 @@ For protobuf-based codegen integrated with the Maven build system, you can use <configuration> <protocArtifact>com.google.protobuf:protoc:3.25.1:exe:${os.detected.classifier}</protocArtifact> <pluginId>grpc-java</pluginId> - <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.63.0:exe:${os.detected.classifier}</pluginArtifact> + <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.64.0:exe:${os.detected.classifier}</pluginArtifact> </configuration> <executions> <execution> @@ -161,7 +161,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.64.0' } } generateProtoTasks { @@ -194,7 +194,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.64.0' } } generateProtoTasks { diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index 55751dd3f..2f366b740 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -70,10 +70,23 @@ public abstract class ClientStreamTracer extends StreamTracer { } /** - * Trailing metadata has been received from the server. + * Headers has been received from the server. This method does not pass ownership to {@code + * headers}, so implementations must not access the metadata after returning. Modifications to the + * metadata within this method will be seen by interceptors and the application. * - * @param trailers the mutable trailing metadata. Modifications to it will be seen by - * interceptors and the application. + * @param headers the received header metadata + */ + public void inboundHeaders(Metadata headers) { + inboundHeaders(); + } + + /** + * Trailing metadata has been received from the server. This method does not pass ownership to + * {@code trailers}, so implementations must not access the metadata after returning. + * Modifications to the metadata within this method will be seen by interceptors and the + * application. + * + * @param trailers the received trailing metadata * @since 1.17.0 */ public void inboundTrailers(Metadata trailers) { diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 7710924d8..f4775c79a 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -47,6 +47,7 @@ import io.grpc.binder.internal.OneWayBinderProxies.BlockingBinderDecorator; import io.grpc.binder.internal.OneWayBinderProxies.ThrowingOneWayBinderProxy; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; @@ -142,34 +143,25 @@ public final class BinderClientTransportTest { } private class BinderClientTransportBuilder { - private SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); - private OneWayBinderProxy.Decorator binderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; + final BinderClientTransportFactory.Builder factoryBuilder = new BinderClientTransportFactory.Builder() + .setSourceContext(appContext) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(executorServicePool); public BinderClientTransportBuilder setSecurityPolicy(SecurityPolicy securityPolicy) { - this.securityPolicy = securityPolicy; + factoryBuilder.setSecurityPolicy(securityPolicy); return this; } public BinderClientTransportBuilder setBinderDecorator( OneWayBinderProxy.Decorator binderDecorator) { - this.binderDecorator = binderDecorator; + factoryBuilder.setBinderDecorator(binderDecorator); return this; } public BinderTransport.BinderClientTransport build() { - return new BinderTransport.BinderClientTransport( - appContext, - BinderChannelCredentials.forDefault(), - serverAddress, - null, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - executorServicePool, - securityPolicy, - InboundParcelablePolicy.DEFAULT, - binderDecorator, - Attributes.EMPTY); + return factoryBuilder.buildClientTransportFactory() + .newClientTransport(serverAddress, new ClientTransportOptions(), null); } } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java index a7f91fae0..a78eb2887 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java @@ -24,11 +24,12 @@ import io.grpc.ServerStreamTracer; import io.grpc.binder.AndroidComponentAddress; import io.grpc.binder.BindServiceFlags; import io.grpc.binder.BinderChannelCredentials; -import io.grpc.binder.BinderInternal; import io.grpc.binder.HostServices; import io.grpc.binder.InboundParcelablePolicy; import io.grpc.binder.SecurityPolicies; import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -68,12 +69,11 @@ public final class BinderTransportTest extends AbstractTransportTest { protected InternalServer newServer(List<ServerStreamTracer.Factory> streamTracerFactories) { AndroidComponentAddress addr = HostServices.allocateService(appContext); - BinderServer binderServer = new BinderServer(addr, - executorServicePool, - streamTracerFactories, - BinderInternal.createPolicyChecker(SecurityPolicies.serverInternalOnly()), - InboundParcelablePolicy.DEFAULT, - /* transportSecurityShutdownListener=*/ () -> {}); + BinderServer binderServer = new BinderServer.Builder() + .setListenAddress(addr) + .setExecutorServicePool(executorServicePool) + .setStreamTracerFactories(streamTracerFactories) + .build(); HostServices.configureService(addr, HostServices.serviceParamsBuilder() @@ -97,19 +97,19 @@ public final class BinderTransportTest extends AbstractTransportTest { @Override protected ManagedClientTransport newClientTransport(InternalServer server) { AndroidComponentAddress addr = (AndroidComponentAddress) server.getListenSocketAddress(); + BinderClientTransportFactory.Builder builder = new BinderClientTransportFactory.Builder() + .setSourceContext(appContext) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(offloadExecutorPool); + + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes(eagAttrs()); + options.setChannelLogger(transportLogger()); + return new BinderTransport.BinderClientTransport( - appContext, - BinderChannelCredentials.forDefault(), + builder.buildClientTransportFactory(), addr, - null, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - offloadExecutorPool, - SecurityPolicies.internalOnly(), - InboundParcelablePolicy.DEFAULT, - OneWayBinderProxy.IDENTITY_DECORATOR, - eagAttrs()); + options); } @Test diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java index 133c8c5dd..67d3631b9 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -22,27 +22,14 @@ import static com.google.common.base.Preconditions.checkState; import android.content.Context; import android.os.UserHandle; import androidx.annotation.RequiresApi; -import androidx.core.content.ContextCompat; import com.google.errorprone.annotations.DoNotCall; -import io.grpc.ChannelCredentials; -import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.binder.internal.BinderTransport; -import io.grpc.binder.internal.OneWayBinderProxy; -import io.grpc.internal.ClientTransportFactory; -import io.grpc.internal.ConnectionClientTransport; +import io.grpc.binder.internal.BinderClientTransportFactory; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.ManagedChannelImplBuilder; -import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.SharedResourcePool; -import java.net.SocketAddress; -import java.util.Collection; -import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -179,14 +166,8 @@ public final class BinderChannelBuilder } private final ManagedChannelImplBuilder managedChannelImplBuilder; + private final BinderClientTransportFactory.Builder transportFactoryBuilder; - private Executor mainThreadExecutor; - private ObjectPool<ScheduledExecutorService> schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - private SecurityPolicy securityPolicy; - private InboundParcelablePolicy inboundParcelablePolicy; - private BindServiceFlags bindServiceFlags; - @Nullable private UserHandle targetUserHandle; private boolean strictLifecycleManagement; private BinderChannelBuilder( @@ -194,41 +175,22 @@ public final class BinderChannelBuilder @Nullable String target, Context sourceContext, BinderChannelCredentials channelCredentials) { - mainThreadExecutor = - ContextCompat.getMainExecutor(checkNotNull(sourceContext, "sourceContext")); - securityPolicy = SecurityPolicies.internalOnly(); - inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; - bindServiceFlags = BindServiceFlags.DEFAULTS; - - final class BinderChannelTransportFactoryBuilder - implements ClientTransportFactoryBuilder { - @Override - public ClientTransportFactory buildClientTransportFactory() { - return new TransportFactory( - sourceContext, - channelCredentials, - mainThreadExecutor, - schedulerPool, - managedChannelImplBuilder.getOffloadExecutorPool(), - securityPolicy, - targetUserHandle, - bindServiceFlags, - inboundParcelablePolicy); - } - } + transportFactoryBuilder = new BinderClientTransportFactory.Builder() + .setSourceContext(sourceContext) + .setChannelCredentials(channelCredentials); if (directAddress != null) { managedChannelImplBuilder = new ManagedChannelImplBuilder( directAddress, directAddress.getAuthority(), - new BinderChannelTransportFactoryBuilder(), + transportFactoryBuilder, null); } else { managedChannelImplBuilder = new ManagedChannelImplBuilder( target, - new BinderChannelTransportFactoryBuilder(), + transportFactoryBuilder, null); } idleTimeout(60, TimeUnit.SECONDS); @@ -242,7 +204,7 @@ public final class BinderChannelBuilder /** Specifies certain optional aspects of the underlying Android Service binding. */ public BinderChannelBuilder setBindServiceFlags(BindServiceFlags bindServiceFlags) { - this.bindServiceFlags = bindServiceFlags; + transportFactoryBuilder.setBindServiceFlags(bindServiceFlags); return this; } @@ -256,8 +218,8 @@ public final class BinderChannelBuilder */ public BinderChannelBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { - schedulerPool = - new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + transportFactoryBuilder.setScheduledExecutorPool( + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService"))); return this; } @@ -269,7 +231,7 @@ public final class BinderChannelBuilder * @return this */ public BinderChannelBuilder mainThreadExecutor(Executor mainThreadExecutor) { - this.mainThreadExecutor = mainThreadExecutor; + transportFactoryBuilder.setMainThreadExecutor(mainThreadExecutor); return this; } @@ -282,7 +244,7 @@ public final class BinderChannelBuilder * @return this */ public BinderChannelBuilder securityPolicy(SecurityPolicy securityPolicy) { - this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + transportFactoryBuilder.setSecurityPolicy(securityPolicy); return this; } @@ -300,14 +262,14 @@ public final class BinderChannelBuilder @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") @RequiresApi(30) public BinderChannelBuilder bindAsUser(UserHandle targetUserHandle) { - this.targetUserHandle = targetUserHandle; + transportFactoryBuilder.setTargetUserHandle(targetUserHandle); return this; } /** Sets the policy for inbound parcelable objects. */ public BinderChannelBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { - this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + transportFactoryBuilder.setInboundParcelablePolicy(inboundParcelablePolicy); return this; } @@ -330,87 +292,10 @@ public final class BinderChannelBuilder return this; } - /** Creates new binder transports. */ - private static final class TransportFactory implements ClientTransportFactory { - private final Context sourceContext; - private final BinderChannelCredentials channelCredentials; - private final Executor mainThreadExecutor; - private final ObjectPool<ScheduledExecutorService> scheduledExecutorPool; - private final ObjectPool<? extends Executor> offloadExecutorPool; - private final SecurityPolicy securityPolicy; - @Nullable private final UserHandle targetUserHandle; - private final BindServiceFlags bindServiceFlags; - private final InboundParcelablePolicy inboundParcelablePolicy; - - private ScheduledExecutorService executorService; - private Executor offloadExecutor; - private boolean closed; - - TransportFactory( - Context sourceContext, - BinderChannelCredentials channelCredentials, - Executor mainThreadExecutor, - ObjectPool<ScheduledExecutorService> scheduledExecutorPool, - ObjectPool<? extends Executor> offloadExecutorPool, - SecurityPolicy securityPolicy, - @Nullable UserHandle targetUserHandle, - BindServiceFlags bindServiceFlags, - InboundParcelablePolicy inboundParcelablePolicy) { - this.sourceContext = sourceContext; - this.channelCredentials = channelCredentials; - this.mainThreadExecutor = mainThreadExecutor; - this.scheduledExecutorPool = scheduledExecutorPool; - this.offloadExecutorPool = offloadExecutorPool; - this.securityPolicy = securityPolicy; - this.targetUserHandle = targetUserHandle; - this.bindServiceFlags = bindServiceFlags; - this.inboundParcelablePolicy = inboundParcelablePolicy; - - executorService = scheduledExecutorPool.getObject(); - offloadExecutor = offloadExecutorPool.getObject(); - } - - @Override - public ConnectionClientTransport newClientTransport( - SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { - if (closed) { - throw new IllegalStateException("The transport factory is closed."); - } - return new BinderTransport.BinderClientTransport( - sourceContext, - channelCredentials, - (AndroidComponentAddress) addr, - targetUserHandle, - bindServiceFlags, - mainThreadExecutor, - scheduledExecutorPool, - offloadExecutorPool, - securityPolicy, - inboundParcelablePolicy, - OneWayBinderProxy.IDENTITY_DECORATOR, - options.getEagAttributes()); - } - - @Override - public ScheduledExecutorService getScheduledExecutorService() { - return executorService; - } - - @Override - public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { - return null; - } - - @Override - public void close() { - closed = true; - executorService = scheduledExecutorPool.returnObject(executorService); - offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); - } - - @Override - public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { - return Collections.singleton(AndroidComponentAddress.class); - } + @Override + public ManagedChannel build() { + transportFactoryBuilder.setOffloadExecutorPool( + managedChannelImplBuilder.getOffloadExecutorPool()); + return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java index 158f7947e..af5f9eed7 100644 --- a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -29,18 +29,13 @@ import io.grpc.ServerBuilder; import io.grpc.binder.internal.BinderServer; import io.grpc.binder.internal.BinderTransportSecurity; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServerImplBuilder; import io.grpc.internal.ObjectPool; -import io.grpc.internal.SharedResourcePool; -import java.io.Closeable; import java.io.File; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; - /** * Builder for a server that services requests from an Android Service. */ @@ -72,28 +67,17 @@ public final class BinderServerBuilder } private final ServerImplBuilder serverImplBuilder; - private ObjectPool<ScheduledExecutorService> schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - private ServerSecurityPolicy securityPolicy; - private InboundParcelablePolicy inboundParcelablePolicy; + private final BinderServer.Builder internalBuilder = new BinderServer.Builder(); private boolean isBuilt; - @Nullable private BinderTransportSecurity.ShutdownListener shutdownListener = null; private BinderServerBuilder( AndroidComponentAddress listenAddress, IBinderReceiver binderReceiver) { - securityPolicy = SecurityPolicies.serverInternalOnly(); - inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + internalBuilder.setListenAddress(listenAddress); serverImplBuilder = new ServerImplBuilder(streamTracerFactories -> { - BinderServer server = new BinderServer( - listenAddress, - schedulerPool, - streamTracerFactories, - BinderInternal.createPolicyChecker(securityPolicy), - inboundParcelablePolicy, - // 'shutdownListener' should have been set by build() - checkNotNull(shutdownListener)); + internalBuilder.setStreamTracerFactories(streamTracerFactories); + BinderServer server = internalBuilder.build(); BinderInternal.setIBinder(binderReceiver, server.getHostBinder()); return server; }); @@ -132,8 +116,8 @@ public final class BinderServerBuilder */ public BinderServerBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { - schedulerPool = - new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + internalBuilder.setExecutorServicePool( + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService"))); return this; } @@ -146,7 +130,7 @@ public final class BinderServerBuilder * @return this */ public BinderServerBuilder securityPolicy(ServerSecurityPolicy securityPolicy) { - this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + internalBuilder.setServerSecurityPolicy(securityPolicy); return this; } @@ -154,7 +138,7 @@ public final class BinderServerBuilder @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public BinderServerBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { - this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + internalBuilder.setInboundParcelablePolicy(inboundParcelablePolicy); return this; } @@ -173,7 +157,7 @@ public final class BinderServerBuilder * * @return the new Server */ - @Override // For javadoc refinement only. + @Override public Server build() { // Since we install a final interceptor here, we need to ensure we're only built once. checkState(!isBuilt, "BinderServerBuilder can only be used to build one server instance."); @@ -182,7 +166,7 @@ public final class BinderServerBuilder ObjectPool<? extends Executor> executorPool = serverImplBuilder.getExecutorPool(); Executor executor = executorPool.getObject(); BinderTransportSecurity.installAuthInterceptor(this, executor); - shutdownListener = () -> executorPool.returnObject(executor); + internalBuilder.setShutdownListener(() -> executorPool.returnObject(executor)); return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java new file mode 100644 index 000000000..abaf07b00 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java @@ -0,0 +1,195 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import android.content.Context; +import android.os.UserHandle; +import androidx.core.content.ContextCompat; +import io.grpc.ChannelCredentials; +import io.grpc.ChannelLogger; +import io.grpc.Internal; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.BinderChannelCredentials; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.SecurityPolicy; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; + +/** + * Creates new binder transports. + */ +@Internal +public final class BinderClientTransportFactory implements ClientTransportFactory { + final Context sourceContext; + final BinderChannelCredentials channelCredentials; + final Executor mainThreadExecutor; + final ObjectPool<ScheduledExecutorService> scheduledExecutorPool; + final ObjectPool<? extends Executor> offloadExecutorPool; + final SecurityPolicy securityPolicy; + @Nullable + final UserHandle targetUserHandle; + final BindServiceFlags bindServiceFlags; + final InboundParcelablePolicy inboundParcelablePolicy; + final OneWayBinderProxy.Decorator binderDecorator; + + ScheduledExecutorService executorService; + Executor offloadExecutor; + private boolean closed; + + private BinderClientTransportFactory(Builder builder) { + sourceContext = checkNotNull(builder.sourceContext); + channelCredentials = checkNotNull(builder.channelCredentials); + mainThreadExecutor = builder.mainThreadExecutor != null ? + builder.mainThreadExecutor : ContextCompat.getMainExecutor(sourceContext); + scheduledExecutorPool = checkNotNull(builder.scheduledExecutorPool); + offloadExecutorPool = checkNotNull(builder.offloadExecutorPool); + securityPolicy = checkNotNull(builder.securityPolicy); + targetUserHandle = builder.targetUserHandle; + bindServiceFlags = checkNotNull(builder.bindServiceFlags); + inboundParcelablePolicy = checkNotNull(builder.inboundParcelablePolicy); + binderDecorator = checkNotNull(builder.binderDecorator); + + executorService = scheduledExecutorPool.getObject(); + offloadExecutor = offloadExecutorPool.getObject(); + } + + @Override + public BinderTransport.BinderClientTransport newClientTransport( + SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { + if (closed) { + throw new IllegalStateException("The transport factory is closed."); + } + return new BinderTransport.BinderClientTransport(this, (AndroidComponentAddress) addr, options); + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return executorService; + } + + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + return null; + } + + @Override + public void close() { + closed = true; + executorService = scheduledExecutorPool.returnObject(executorService); + offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); + } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(AndroidComponentAddress.class); + } + + /** + * Allows fluent construction of ClientTransportFactory. + */ + public static final class Builder implements ClientTransportFactoryBuilder { + // Required. + Context sourceContext; + ObjectPool<? extends Executor> offloadExecutorPool; + + // Optional. + BinderChannelCredentials channelCredentials = BinderChannelCredentials.forDefault(); + Executor mainThreadExecutor; // Default filled-in at build time once sourceContext is decided. + ObjectPool<ScheduledExecutorService> scheduledExecutorPool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); + @Nullable + UserHandle targetUserHandle; + BindServiceFlags bindServiceFlags = BindServiceFlags.DEFAULTS; + InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + OneWayBinderProxy.Decorator binderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; + + @Override + public BinderClientTransportFactory buildClientTransportFactory() { + return new BinderClientTransportFactory(this); + } + + public Builder setSourceContext(Context sourceContext) { + this.sourceContext = checkNotNull(sourceContext); + return this; + } + + public Builder setOffloadExecutorPool( + ObjectPool<? extends Executor> offloadExecutorPool) { + this.offloadExecutorPool = checkNotNull(offloadExecutorPool, "offloadExecutorPool"); + return this; + } + + public Builder setChannelCredentials(BinderChannelCredentials channelCredentials) { + this.channelCredentials = checkNotNull(channelCredentials, "channelCredentials"); + return this; + } + + public Builder setMainThreadExecutor(Executor mainThreadExecutor) { + this.mainThreadExecutor = checkNotNull(mainThreadExecutor, "mainThreadExecutor"); + return this; + } + + public Builder setScheduledExecutorPool( + ObjectPool<ScheduledExecutorService> scheduledExecutorPool) { + this.scheduledExecutorPool = checkNotNull(scheduledExecutorPool, "scheduledExecutorPool"); + return this; + } + + public Builder setSecurityPolicy(SecurityPolicy securityPolicy) { + this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + return this; + } + + public Builder setTargetUserHandle(@Nullable UserHandle targetUserHandle) { + this.targetUserHandle = targetUserHandle; + return this; + } + + public Builder setBindServiceFlags(BindServiceFlags bindServiceFlags) { + this.bindServiceFlags = checkNotNull(bindServiceFlags, "bindServiceFlags"); + return this; + } + + public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + /** + * Decorates both the "endpoint" and "server" binders, for fault injection. + * + * <p>Optional. If absent, these objects will go undecorated. + */ + public Builder setBinderDecorator(OneWayBinderProxy.Decorator binderDecorator) { + this.binderDecorator = checkNotNull(binderDecorator, "binderDecorator"); + return this; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java index 72faa33ad..03af19c04 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -28,10 +28,15 @@ import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.ServerStreamTracer; import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BinderInternal; import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.ServerSecurityPolicy; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ObjectPool; import io.grpc.internal.ServerListener; +import io.grpc.internal.SharedResourcePool; import java.io.IOException; import java.net.SocketAddress; import java.util.List; @@ -68,24 +73,14 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. @GuardedBy("this") private boolean shutdown; - /** - * @param transportSecurityShutdownListener represents resources that should be cleaned up once - * the server shuts down. - */ - public BinderServer( - AndroidComponentAddress listenAddress, - ObjectPool<ScheduledExecutorService> executorServicePool, - List<? extends ServerStreamTracer.Factory> streamTracerFactories, - BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker, - InboundParcelablePolicy inboundParcelablePolicy, - BinderTransportSecurity.ShutdownListener transportSecurityShutdownListener) { - this.listenAddress = listenAddress; - this.executorServicePool = executorServicePool; + private BinderServer(Builder builder) { + this.listenAddress = checkNotNull(builder.listenAddress); + this.executorServicePool = builder.executorServicePool; this.streamTracerFactories = - ImmutableList.copyOf(checkNotNull(streamTracerFactories, "streamTracerFactories")); - this.serverPolicyChecker = checkNotNull(serverPolicyChecker, "serverPolicyChecker"); - this.inboundParcelablePolicy = inboundParcelablePolicy; - this.transportSecurityShutdownListener = transportSecurityShutdownListener; + ImmutableList.copyOf(checkNotNull(builder.streamTracerFactories, "streamTracerFactories")); + this.serverPolicyChecker = BinderInternal.createPolicyChecker(builder.serverSecurityPolicy); + this.inboundParcelablePolicy = builder.inboundParcelablePolicy; + this.transportSecurityShutdownListener = builder.shutdownListener; hostServiceBinder = new LeakSafeOneWayBinder(this); } @@ -169,4 +164,84 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. } return false; } + + /** Fluent builder of {@link BinderServer} instances. */ + public static class Builder { + @Nullable AndroidComponentAddress listenAddress; + @Nullable List<? extends ServerStreamTracer.Factory> streamTracerFactories; + + ObjectPool<ScheduledExecutorService> executorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + ServerSecurityPolicy serverSecurityPolicy = SecurityPolicies.serverInternalOnly(); + InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + BinderTransportSecurity.ShutdownListener shutdownListener = () -> {}; + + public BinderServer build() { + return new BinderServer(this); + } + + /** + * Sets the "listen" address for this server. + * + * <p>This is somewhat of a grpc-java formality. Binder servers don't really listen, rather, + * Android creates and destroys them according to client needs. + * + * <p>Required. + */ + public Builder setListenAddress(AndroidComponentAddress listenAddress) { + this.listenAddress = listenAddress; + return this; + } + + /** + * Sets the source for {@link ServerStreamTracer}s that will be installed on all new streams. + * + * <p>Required. + */ + public Builder setStreamTracerFactories(List<? extends ServerStreamTracer.Factory> streamTracerFactories) { + this.streamTracerFactories = streamTracerFactories; + return this; + } + + /** + * Sets the executor to be used for scheduling channel timers. + * + * <p>Optional. A process-wide default executor will be used if unset. + */ + public Builder setExecutorServicePool( + ObjectPool<ScheduledExecutorService> executorServicePool) { + this.executorServicePool = checkNotNull(executorServicePool, "executorServicePool"); + return this; + } + + /** + * Sets the {@link ServerSecurityPolicy} to be used for built servers. + * + * Optional, {@link SecurityPolicies#serverInternalOnly()} is the default. + */ + public Builder setServerSecurityPolicy(ServerSecurityPolicy serverSecurityPolicy) { + this.serverSecurityPolicy = checkNotNull(serverSecurityPolicy, "serverSecurityPolicy"); + return this; + } + + /** + * Sets the {@link InboundParcelablePolicy} to be used for built servers. + * + * Optional, {@link InboundParcelablePolicy#DEFAULT} is the default. + */ + public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + /** + * Installs a callback that will be invoked when this server is {@link #shutdown()} + * + * <p>Optional. + */ + public Builder setShutdownListener(BinderTransportSecurity.ShutdownListener shutdownListener) { + this.shutdownListener = checkNotNull(shutdownListener, "shutdownListener"); + return this; + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 4a33adb21..2703e0ae9 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -28,7 +28,6 @@ import android.os.Parcel; import android.os.Process; import android.os.RemoteException; import android.os.TransactionTooLargeException; -import android.os.UserHandle; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.base.Verify; @@ -47,11 +46,10 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BindServiceFlags; -import io.grpc.binder.BinderChannelCredentials; import io.grpc.binder.InboundParcelablePolicy; import io.grpc.binder.SecurityPolicy; import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.FailingClientStream; import io.grpc.internal.GrpcAttributes; @@ -573,41 +571,34 @@ public abstract class BinderTransport /** * Constructs a new transport instance. * - * @param binderDecorator used to decorate both the "endpoint" and "server" binders, for fault - * injection. + * @param factory parameters common to all a Channel's transports + * @param targetAddress the fully resolved and load-balanced server address + * @param options other parameters that can vary as transports come and go within a Channel */ public BinderClientTransport( - Context sourceContext, - BinderChannelCredentials channelCredentials, + BinderClientTransportFactory factory, AndroidComponentAddress targetAddress, - @Nullable UserHandle targetUserHandle, - BindServiceFlags bindServiceFlags, - Executor mainThreadExecutor, - ObjectPool<ScheduledExecutorService> executorServicePool, - ObjectPool<? extends Executor> offloadExecutorPool, - SecurityPolicy securityPolicy, - InboundParcelablePolicy inboundParcelablePolicy, - OneWayBinderProxy.Decorator binderDecorator, - Attributes eagAttrs) { + ClientTransportOptions options) { super( - executorServicePool, - buildClientAttributes(eagAttrs, sourceContext, targetAddress, inboundParcelablePolicy), - binderDecorator, - buildLogId(sourceContext, targetAddress)); - this.offloadExecutorPool = offloadExecutorPool; - this.securityPolicy = securityPolicy; + factory.scheduledExecutorPool, + buildClientAttributes(options.getEagAttributes(), + factory.sourceContext, targetAddress, factory.inboundParcelablePolicy), + factory.binderDecorator, + buildLogId(factory.sourceContext, targetAddress)); + this.offloadExecutorPool = factory.offloadExecutorPool; + this.securityPolicy = factory.securityPolicy; this.offloadExecutor = offloadExecutorPool.getObject(); numInUseStreams = new AtomicInteger(); pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); serviceBinding = new ServiceBinding( - mainThreadExecutor, - sourceContext, - channelCredentials, + factory.mainThreadExecutor, + factory.sourceContext, + factory.channelCredentials, targetAddress.asBindIntent(), - targetUserHandle, - bindServiceFlags.toInteger(), + factory.targetUserHandle, + factory.bindServiceFlags.toInteger(), this); } diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index 5ab96085a..23f11ccda 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -579,7 +579,7 @@ abstract class Inbound<L extends StreamListener> implements StreamListener.Messa @GuardedBy("this") protected void handlePrefix(int flags, Parcel parcel) throws StatusException { Metadata headers = MetadataHelper.readMetadata(parcel, attributes); - statsTraceContext.clientInboundHeaders(); + statsTraceContext.clientInboundHeaders(headers); listener.headersRead(headers); } diff --git a/core/build.gradle b/core/build.gradle index 22c68b211..f8a95c372 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -32,7 +32,6 @@ dependencies { libraries.truth, project(':grpc-testing') testImplementation testFixtures(project(':grpc-api')), - project(':grpc-inprocess'), project(':grpc-testing') testImplementation libraries.guava.testlib diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index e929716e8..51c31993f 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -304,7 +304,7 @@ public abstract class AbstractClientStream extends AbstractStream */ protected void inboundHeadersReceived(Metadata headers) { checkState(!statusReported, "Received headers on closed stream"); - statsTraceCtx.clientInboundHeaders(); + statsTraceCtx.clientInboundHeaders(headers); boolean compressedStream = false; String streamEncoding = headers.get(CONTENT_ENCODING_KEY); diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index dc4bce014..6eebfdd0f 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -140,7 +140,7 @@ final class DelayedClientTransport implements ManagedClientTransport { } } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible - // race with reprocess()), we will buffer it. Otherwise, will try with the new picker. + // race with reprocess()), we will buffer the RPC. Otherwise, will try with the new picker. synchronized (lock) { PickerState newerState = pickerState; if (state == newerState) { diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java index b3e9b216d..e7679ea14 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java @@ -50,6 +50,11 @@ public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { } @Override + public void inboundHeaders(Metadata headers) { + delegate().inboundHeaders(headers); + } + + @Override public void inboundTrailers(Metadata trailers) { delegate().inboundTrailers(trailers); } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 09ca4684b..b21fc97e6 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -2081,6 +2081,12 @@ final class ManagedChannelImpl extends ManagedChannel implements @Override public void transportInUse(final boolean inUse) { inUseStateAggregator.updateObjectInUse(delayedTransport, inUse); + if (inUse) { + // It's possible to be in idle mode while inUseStateAggregator is in-use, if one of the + // subchannels is in use. But we should never be in idle mode when delayed transport is in + // use. + exitIdleMode(); + } } @Override diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java index 889be30e7..650f0b979 100644 --- a/core/src/main/java/io/grpc/internal/StatsTraceContext.java +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -101,9 +101,9 @@ public final class StatsTraceContext { * * <p>Called from abstract stream implementations. */ - public void clientInboundHeaders() { + public void clientInboundHeaders(Metadata headers) { for (StreamTracer tracer : tracers) { - ((ClientStreamTracer) tracer).inboundHeaders(); + ((ClientStreamTracer) tracer).inboundHeaders(headers); } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index d6ae0a532..ce446bbab 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -46,7 +46,6 @@ import io.grpc.MetricSink; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; import io.grpc.StaticTestingClassLoader; -import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; @@ -367,7 +366,7 @@ public class ManagedChannelImplBuilderTest { when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) .thenReturn(mockClientTransportFactory); when(mockClientTransportFactory.getSupportedSocketAddressTypes()) - .thenReturn(Collections.singleton(InProcessSocketAddress.class)); + .thenReturn(Collections.singleton(CustomSocketAddress.class)); builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); @@ -782,4 +781,6 @@ public class ManagedChannelImplBuilderTest { assertFalse(uriPattern.matcher("a,:/").matches()); // ',' not matched assertFalse(uriPattern.matcher(" a:/").matches()); // space not matched } + + private static class CustomSocketAddress extends SocketAddress {} } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index 98300bc82..a0bd388b1 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -22,8 +22,8 @@ import static org.junit.Assert.fail; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; -import io.grpc.inprocess.InProcessSocketAddress; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; import java.util.Collections; import org.junit.Test; @@ -110,7 +110,7 @@ public class ManagedChannelImplGetNameResolverTest { try { ManagedChannelImplBuilder.getNameResolverProvider( "testscheme:///foo.googleapis.com:8080", nameResolverRegistry, - Collections.singleton(InProcessSocketAddress.class)); + Collections.singleton(CustomSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().isEqualTo( @@ -196,4 +196,6 @@ public class ManagedChannelImplGetNameResolverTest { @Override public void shutdown() {} } + + private static class CustomSocketAddress extends SocketAddress {} } diff --git a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java index 57d870575..103c12475 100644 --- a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java @@ -351,6 +351,26 @@ public abstract class AbstractTransportTest { } @Test + public void clientShutdownBeforeStartRunnable() throws Exception { + server.start(serverListener); + client = newClientTransport(server); + Runnable runnable = client.start(mockClientTransportListener); + // Shutdown before calling 'runnable' + client.shutdown(Status.UNAVAILABLE.withDescription("shutdown called")); + runIfNotNull(runnable); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); + // We should verify that clients don't call transportReady() after transportTerminated(), but + // transports do this today and nothing cares. ServerImpl, on the other hand, doesn't appreciate + // the out-of-order calls. + MockServerTransportListener serverTransportListener + = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + // Allow any status as some transports (e.g., Netty) don't communicate the original status when + // shutdown while handshaking. It won't be used anyway, so no big deal. + verify(mockClientTransportListener).transportShutdown(any(Status.class)); + } + + @Test public void clientStartAndStopOnceConnected() throws Exception { server.start(serverListener); client = newClientTransport(server); @@ -2251,6 +2271,7 @@ public abstract class AbstractTransportTest { @Override public Attributes transportReady(Attributes attributes) { + assertFalse(terminated.isDone()); return Attributes.newBuilder() .setAll(attributes) .set(ADDITIONAL_TRANSPORT_ATTR_KEY, "additional attribute value") diff --git a/core/src/testFixtures/java/io/grpc/internal/FakeClock.java b/core/src/testFixtures/java/io/grpc/internal/FakeClock.java index 9cc9178f1..1a3584f4e 100644 --- a/core/src/testFixtures/java/io/grpc/internal/FakeClock.java +++ b/core/src/testFixtures/java/io/grpc/internal/FakeClock.java @@ -188,7 +188,8 @@ public final class FakeClock { } @Override public boolean isShutdown() { - throw new UnsupportedOperationException(); + // If shutdown is not implemented, then it is never shutdown. + return false; } @Override public boolean isTerminated() { diff --git a/examples/example-reflection/README.md b/examples/example-reflection/README.md index 9bd91f3ed..801a27343 100644 --- a/examples/example-reflection/README.md +++ b/examples/example-reflection/README.md @@ -45,7 +45,7 @@ Output ### List all the methods of a service ``` - $ grpcurl -plaintext localhost:50051 helloworld.Greeter + $ grpcurl -plaintext localhost:50051 list helloworld.Greeter ``` Output ``` diff --git a/gcp-csm-observability/build.gradle b/gcp-csm-observability/build.gradle new file mode 100644 index 000000000..4f1262165 --- /dev/null +++ b/gcp-csm-observability/build.gradle @@ -0,0 +1,31 @@ +plugins { + id "java-library" + + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: GCP CSM Observability" + +tasks.named("jar").configure { + manifest { + attributes('Automatic-Module-Name': 'io.grpc.gcp.csm.observability') + } +} + +dependencies { + implementation project(':grpc-api'), + project(':grpc-core'), + project(':grpc-opentelemetry'), + project(':grpc-protobuf'), + project(':grpc-xds'), + libraries.guava.jre, // jre version pulled in via xds + libraries.protobuf.java, + libraries.opentelemetry.gcp.resources, + libraries.opentelemetry.sdk.extension.autoconfigure // opentelemetry.gcp.resources uses compileOnly for this dep + testImplementation project(":grpc-testing"), + project(":grpc-inprocess"), + libraries.opentelemetry.sdk.testing, + libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for this dep + + signature libraries.signature.java +} diff --git a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java new file mode 100644 index 000000000..c345fb35d --- /dev/null +++ b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java @@ -0,0 +1,160 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.ExperimentalApi; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.opentelemetry.InternalGrpcOpenTelemetry; +import io.opentelemetry.api.OpenTelemetry; +import java.io.Closeable; +import java.util.Collection; +import java.util.Collections; + +/** + * The entrypoint for GCP's CSM OpenTelemetry metrics functionality in gRPC. + * + * <p>CsmObservability uses {@link io.opentelemetry.api.OpenTelemetry} APIs for instrumentation. + * When no SDK is explicitly added no telemetry data will be collected. See + * {@code io.opentelemetry.sdk.OpenTelemetrySdk} for information on how to construct the SDK. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11249") +public final class CsmObservability implements Closeable { + private final GrpcOpenTelemetry delegate; + private final MetadataExchanger exchanger; + + public static Builder newBuilder() { + return new Builder(); + } + + private CsmObservability(Builder builder) { + this.delegate = builder.delegate.build(); + this.exchanger = builder.exchanger; + } + + /** + * Registers CsmObservability globally, applying its configuration to all subsequently created + * gRPC channels and servers. + * + * <p>Note: Only one of CsmObservability and GrpcOpenTelemetry instance can be registered + * globally. Any subsequent call to {@code registerGlobal()} will throw an {@code + * IllegalStateException}. + */ + public void registerGlobal() { + InternalConfiguratorRegistry.setConfigurators(Collections.singletonList( + new InternalConfigurator() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder<?> channelBuilder) { + CsmObservability.this.configureChannelBuilder(channelBuilder); + } + + @Override + public void configureServerBuilder(ServerBuilder<?> serverBuilder) { + CsmObservability.this.configureServerBuilder(serverBuilder); + } + })); + } + + @VisibleForTesting + void configureChannelBuilder(ManagedChannelBuilder<?> builder) { + delegate.configureChannelBuilder(builder); + } + + @VisibleForTesting + void configureServerBuilder(ServerBuilder<?> serverBuilder) { + delegate.configureServerBuilder(serverBuilder); + exchanger.configureServerBuilder(serverBuilder); + } + + @Override + public void close() {} + + /** + * Builder for configuring {@link CsmObservability}. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11249") + public static final class Builder { + private final GrpcOpenTelemetry.Builder delegate = GrpcOpenTelemetry.newBuilder(); + private final MetadataExchanger exchanger; + + private Builder() { + this(new MetadataExchanger()); + } + + @VisibleForTesting + Builder(MetadataExchanger exchanger) { + this.exchanger = exchanger; + InternalGrpcOpenTelemetry.builderPlugin(delegate, exchanger); + } + + /** + * Sets the {@link io.opentelemetry.api.OpenTelemetry} entrypoint to use. This can be used to + * configure OpenTelemetry by returning the instance created by a + * {@code io.opentelemetry.sdk.OpenTelemetrySdkBuilder}. + */ + public Builder sdk(OpenTelemetry sdk) { + delegate.sdk(sdk); + return this; + } + + /** + * Adds optionalLabelKey to all the metrics that can provide value for the + * optionalLabelKey. + */ + public Builder addOptionalLabel(String optionalLabelKey) { + delegate.addOptionalLabel(optionalLabelKey); + return this; + } + + /** + * Enables the specified metrics for collection and export. By default, only a subset of + * metrics are enabled. + */ + public Builder enableMetrics(Collection<String> enableMetrics) { + delegate.enableMetrics(enableMetrics); + return this; + } + + /** + * Disables the specified metrics from being collected and exported. + */ + public Builder disableMetrics(Collection<String> disableMetrics) { + delegate.disableMetrics(disableMetrics); + return this; + } + + /** + * Disable all metrics. If set to true all metrics must be explicitly enabled. + */ + public Builder disableAllMetrics() { + delegate.disableAllMetrics(); + return this; + } + + /** + * Returns a new {@link CsmObservability} built with the configuration of this {@link + * Builder}. + */ + public CsmObservability build() { + return new CsmObservability(this); + } + } +} diff --git a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java new file mode 100644 index 000000000..24e035bdc --- /dev/null +++ b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java @@ -0,0 +1,364 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.CallOptions; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.internal.JsonParser; +import io.grpc.internal.JsonUtil; +import io.grpc.opentelemetry.InternalOpenTelemetryPlugin; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.xds.ClusterImplLoadBalancerProvider; +import io.grpc.xds.InternalGrpcBootstrapperImpl; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.contrib.gcp.resource.GCPResourceProvider; +import java.net.URI; +import java.util.Map; +import java.util.function.Consumer; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * OpenTelemetryPlugin implementing metadata-based workload property exchange for both client and + * server. Is responsible for determining the metadata, communicating the metadata, and adding local + * and remote details to metrics. + */ +final class MetadataExchanger implements InternalOpenTelemetryPlugin { + private static final Logger logger = Logger.getLogger(MetadataExchanger.class.getName()); + + private static final AttributeKey<String> CLOUD_PLATFORM = + AttributeKey.stringKey("cloud.platform"); + private static final AttributeKey<String> K8S_NAMESPACE_NAME = + AttributeKey.stringKey("k8s.namespace.name"); + private static final AttributeKey<String> K8S_CLUSTER_NAME = + AttributeKey.stringKey("k8s.cluster.name"); + private static final AttributeKey<String> CLOUD_AVAILABILITY_ZONE = + AttributeKey.stringKey("cloud.availability_zone"); + private static final AttributeKey<String> CLOUD_REGION = + AttributeKey.stringKey("cloud.region"); + private static final AttributeKey<String> CLOUD_ACCOUNT_ID = + AttributeKey.stringKey("cloud.account.id"); + + private static final Metadata.Key<String> SEND_KEY = + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key<Struct> RECV_KEY = + Metadata.Key.of("x-envoy-peer-metadata", new BinaryToAsciiMarshaller<>( + ProtoUtils.metadataMarshaller(Struct.getDefaultInstance()))); + + private static final String EXCHANGE_TYPE = "type"; + private static final String EXCHANGE_CANONICAL_SERVICE = "canonical_service"; + private static final String EXCHANGE_PROJECT_ID = "project_id"; + private static final String EXCHANGE_LOCATION = "location"; + private static final String EXCHANGE_CLUSTER_NAME = "cluster_name"; + private static final String EXCHANGE_NAMESPACE_NAME = "namespace_name"; + private static final String EXCHANGE_WORKLOAD_NAME = "workload_name"; + private static final String TYPE_GKE = "gcp_kubernetes_engine"; + private static final String TYPE_GCE = "gcp_compute_engine"; + + private final String localMetadata; + private final Attributes localAttributes; + + public MetadataExchanger() { + this( + new GCPResourceProvider().getAttributes(), + System::getenv, + InternalGrpcBootstrapperImpl::getJsonContent); + } + + MetadataExchanger(Attributes platformAttributes, Lookup env, Supplier<String> xdsBootstrap) { + String type = platformAttributes.get(CLOUD_PLATFORM); + String canonicalService = env.get("CSM_CANONICAL_SERVICE_NAME"); + Struct.Builder struct = Struct.newBuilder(); + put(struct, EXCHANGE_TYPE, type); + put(struct, EXCHANGE_CANONICAL_SERVICE, canonicalService); + if (TYPE_GKE.equals(type)) { + String location = platformAttributes.get(CLOUD_AVAILABILITY_ZONE); + if (location == null) { + location = platformAttributes.get(CLOUD_REGION); + } + put(struct, EXCHANGE_WORKLOAD_NAME, env.get("CSM_WORKLOAD_NAME")); + put(struct, EXCHANGE_NAMESPACE_NAME, platformAttributes.get(K8S_NAMESPACE_NAME)); + put(struct, EXCHANGE_CLUSTER_NAME, platformAttributes.get(K8S_CLUSTER_NAME)); + put(struct, EXCHANGE_LOCATION, location); + put(struct, EXCHANGE_PROJECT_ID, platformAttributes.get(CLOUD_ACCOUNT_ID)); + } else if (TYPE_GCE.equals(type)) { + String location = platformAttributes.get(CLOUD_AVAILABILITY_ZONE); + if (location == null) { + location = platformAttributes.get(CLOUD_REGION); + } + put(struct, EXCHANGE_WORKLOAD_NAME, env.get("CSM_WORKLOAD_NAME")); + put(struct, EXCHANGE_LOCATION, location); + put(struct, EXCHANGE_PROJECT_ID, platformAttributes.get(CLOUD_ACCOUNT_ID)); + } + localMetadata = BaseEncoding.base64().encode(struct.build().toByteArray()); + + localAttributes = Attributes.builder() + .put("csm.mesh_id", nullIsUnknown(getMeshId(xdsBootstrap))) + .put("csm.workload_canonical_service", nullIsUnknown(canonicalService)) + .build(); + } + + private static String nullIsUnknown(String value) { + return value == null ? "unknown" : value; + } + + private static void put(Struct.Builder struct, String key, String value) { + value = nullIsUnknown(value); + struct.putFields(key, Value.newBuilder().setStringValue(value).build()); + } + + private static void put(AttributesBuilder attributes, String key, Value value) { + attributes.put(key, nullIsUnknown(fromValue(value))); + } + + private static String fromValue(Value value) { + if (value == null) { + return null; + } + if (value.getKindCase() != Value.KindCase.STRING_VALUE) { + return null; + } + return value.getStringValue(); + } + + @VisibleForTesting + static String getMeshId(Supplier<String> xdsBootstrap) { + try { + @SuppressWarnings("unchecked") + Map<String, ?> rawBootstrap = (Map<String, ?>) JsonParser.parse(xdsBootstrap.get()); + Map<String, ?> node = JsonUtil.getObject(rawBootstrap, "node"); + String id = JsonUtil.getString(node, "id"); + Preconditions.checkNotNull(id, "id"); + String[] parts = id.split("/", 6); + if (!(parts.length == 6 + && parts[0].equals("projects") + && parts[2].equals("networks") + && parts[3].startsWith("mesh:") + && parts[4].equals("nodes"))) { + throw new Exception("node id didn't match mesh format: " + id); + } + return parts[3].substring("mesh:".length()); + } catch (Exception e) { + logger.log(Level.INFO, "Failed to determine mesh ID for CSM", e); + return null; + } + } + + private void addLabels(AttributesBuilder to, Struct struct) { + to.putAll(localAttributes); + Map<String, Value> remote = struct.getFieldsMap(); + Value typeValue = remote.get(EXCHANGE_TYPE); + String type = fromValue(typeValue); + put(to, "csm.remote_workload_type", typeValue); + put(to, "csm.remote_workload_canonical_service", remote.get(EXCHANGE_CANONICAL_SERVICE)); + if (TYPE_GKE.equals(type)) { + put(to, "csm.remote_workload_project_id", remote.get(EXCHANGE_PROJECT_ID)); + put(to, "csm.remote_workload_location", remote.get(EXCHANGE_LOCATION)); + put(to, "csm.remote_workload_cluster_name", remote.get(EXCHANGE_CLUSTER_NAME)); + put(to, "csm.remote_workload_namespace_name", remote.get(EXCHANGE_NAMESPACE_NAME)); + put(to, "csm.remote_workload_name", remote.get(EXCHANGE_WORKLOAD_NAME)); + } else if (TYPE_GCE.equals(type)) { + put(to, "csm.remote_workload_project_id", remote.get(EXCHANGE_PROJECT_ID)); + put(to, "csm.remote_workload_location", remote.get(EXCHANGE_LOCATION)); + put(to, "csm.remote_workload_name", remote.get(EXCHANGE_WORKLOAD_NAME)); + } + } + + @Override + public boolean enablePluginForChannel(String target) { + URI uri; + try { + uri = new URI(target); + } catch (Exception ex) { + return false; + } + String authority = uri.getAuthority(); + return "xds".equals(uri.getScheme()) + && (authority == null || "traffic-director-global.xds.googleapis.com".equals(authority)); + } + + @Override + public ClientCallPlugin newClientCallPlugin() { + return new ClientCallState(); + } + + public void configureServerBuilder(ServerBuilder<?> serverBuilder) { + serverBuilder.intercept(new ServerCallInterceptor()); + } + + @Override + public ServerStreamPlugin newServerStreamPlugin(Metadata inboundMetadata) { + return new ServerStreamState(inboundMetadata.get(RECV_KEY)); + } + + final class ClientCallState implements ClientCallPlugin { + private volatile Value serviceName; + private volatile Value serviceNamespace; + + @Override + public ClientStreamPlugin newClientStreamPlugin() { + return new ClientStreamState(); + } + + @Override + public CallOptions filterCallOptions(CallOptions options) { + Consumer<Map<String, Struct>> existingConsumer = + options.getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER); + return options.withOption( + ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER, + (Map<String, Struct> clusterMetadata) -> { + metadataConsumer(clusterMetadata); + existingConsumer.accept(clusterMetadata); + }); + } + + private void metadataConsumer(Map<String, Struct> clusterMetadata) { + Struct struct = clusterMetadata.get("com.google.csm.telemetry_labels"); + if (struct == null) { + struct = Struct.getDefaultInstance(); + } + serviceName = struct.getFieldsMap().get("service_name"); + serviceNamespace = struct.getFieldsMap().get("service_namespace"); + } + + @Override + public void addMetadata(Metadata toMetadata) { + toMetadata.put(SEND_KEY, localMetadata); + } + + class ClientStreamState implements ClientStreamPlugin { + private Struct receivedExchange; + + @Override + public void inboundHeaders(Metadata headers) { + setExchange(headers); + } + + @Override + public void inboundTrailers(Metadata trailers) { + if (receivedExchange != null) { + return; // Received headers + } + setExchange(trailers); + } + + private void setExchange(Metadata metadata) { + Struct received = metadata.get(RECV_KEY); + if (received == null) { + receivedExchange = Struct.getDefaultInstance(); + } else { + receivedExchange = received; + } + } + + @Override + public void addLabels(AttributesBuilder to) { + put(to, "csm.service_name", serviceName); + put(to, "csm.service_namespace", serviceNamespace); + Struct exchange = receivedExchange; + if (exchange == null) { + exchange = Struct.getDefaultInstance(); + } + MetadataExchanger.this.addLabels(to, exchange); + } + } + } + + final class ServerCallInterceptor implements ServerInterceptor { + @Override + public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( + ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { + if (!headers.containsKey(RECV_KEY)) { + return next.startCall(call, headers); + } else { + return next.startCall(new SimpleForwardingServerCall<ReqT, RespT>(call) { + private boolean headersSent; + + @Override + public void sendHeaders(Metadata headers) { + headersSent = true; + headers.put(SEND_KEY, localMetadata); + super.sendHeaders(headers); + } + + @Override + public void close(Status status, Metadata trailers) { + if (!headersSent) { + trailers.put(SEND_KEY, localMetadata); + } + super.close(status, trailers); + } + }, headers); + } + } + } + + final class ServerStreamState implements ServerStreamPlugin { + private final Struct receivedExchange; + + ServerStreamState(Struct exchange) { + if (exchange == null) { + exchange = Struct.getDefaultInstance(); + } + receivedExchange = exchange; + } + + @Override + public void addLabels(AttributesBuilder to) { + MetadataExchanger.this.addLabels(to, receivedExchange); + } + } + + interface Lookup { + String get(String name); + } + + interface Supplier<T> { + T get() throws Exception; + } + + static final class BinaryToAsciiMarshaller<T> implements Metadata.AsciiMarshaller<T> { + private final Metadata.BinaryMarshaller<T> delegate; + + public BinaryToAsciiMarshaller(Metadata.BinaryMarshaller<T> delegate) { + this.delegate = Preconditions.checkNotNull(delegate, "delegate"); + } + + @Override + public T parseAsciiString(String serialized) { + return delegate.parseBytes(BaseEncoding.base64().decode(serialized)); + } + + @Override + public String toAsciiString(T value) { + return BaseEncoding.base64().encode(delegate.toBytes(value)); + } + } +} diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java new file mode 100644 index 000000000..55287a2b9 --- /dev/null +++ b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java @@ -0,0 +1,636 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import static com.google.common.truth.Truth.assertThat; +import static io.opentelemetry.api.common.AttributeKey.stringKey; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.inprocess.InProcessSocketAddress; +import io.grpc.internal.testing.FakeNameResolverProvider; +import io.grpc.stub.ClientCalls; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.ClusterImplLoadBalancerProvider; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link CsmObservability}. */ +@RunWith(JUnit4.class) +public final class CsmObservabilityTest { + @Rule + public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + + private NameResolverProvider fakeNameResolverProvider; + private InProcessSocketAddress socketAddress = new InProcessSocketAddress("csm-test-server"); + private ServerBuilder<?> serverBuilder = InProcessServerBuilder.forAddress(socketAddress) + .addService(voidService(Status.OK)) + .directExecutor(); + + @After + public void tearDown() { + if (fakeNameResolverProvider != null) { + NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverProvider); + } + } + + @Test + public void unknownDataExchange() throws Exception { + String xdsBootstrap = ""; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.<String, String>of()::get, + () -> xdsBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.<String, String>of()::get, + () -> xdsBootstrap); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void nonCsmServer() throws Exception { + String xdsBootstrap = ""; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.<String, String>of()::get, + () -> xdsBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + OpenTelemetryAssertions.assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientEndAttributes)))); + } + + @Test + public void nonCsmClient() throws Exception { + String xdsBootstrap = ""; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .build(), + ImmutableMap.<String, String>of()::get, + () -> xdsBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.<String, String>of()::get, + () -> xdsBootstrap); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds://not-a-csm-authority/csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + preexistingClientEndAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void k8sExchange() throws Exception { + // Purposefully use a different project ID in the bootstrap than the resource, as the mesh could + // be in a different project than the running account. + String clientBootstrap = "{\"node\": {" + + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" + + "}}"; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "namespace-aeiou") + .put(stringKey("k8s.cluster.name"), "mycluster1") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client")::get, + () -> clientBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + String serverBootstrap = "{\"node\": {" + + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" + + "}}"; + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "namespace-1e43c") + .put(stringKey("k8s.cluster.name"), "mycluster2") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server")::get, + () -> serverBootstrap); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test-k8s"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor() + .intercept(new ProvideFilterMetadataInterceptor( + ImmutableMap.of("com.google.csm.telemetry_labels", Struct.newBuilder() + .putFields("service_name", + Value.newBuilder().setStringValue("second-server-name").build()) + .putFields("service_namespace", + Value.newBuilder().setStringValue("namespace-0001").build()) + .build()))); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_kubernetes_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_cluster_name"), "mycluster2") + .put(stringKey("csm.remote_workload_namespace_name"), "namespace-1e43c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "second-server-name") + .put(stringKey("csm.service_namespace"), "namespace-0001") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_kubernetes_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_cluster_name"), "mycluster1") + .put(stringKey("csm.remote_workload_namespace_name"), "namespace-aeiou") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void gceExchange() throws Exception { + // Purposefully use a different project ID in the bootstrap than the resource, as the mesh could + // be in a different project than the running account. + String clientBootstrap = "{\"node\": {" + + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" + + "}}"; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client")::get, + () -> clientBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + String serverBootstrap = "{\"node\": {" + + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" + + "}}"; + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server")::get, + () -> serverBootstrap); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor() + .intercept(new ProvideFilterMetadataInterceptor(ImmutableMap.<String, Struct>of())); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void trailersOnly() throws Exception { + String clientBootstrap = "{\"node\": {" + + "\"id\": \"projects/12/networks/mesh:mymesh/nodes/a6420022-cbc5-4e10-808c-507e3fc95f2e\"" + + "}}"; + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client")::get, + () -> clientBootstrap); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + String serverBootstrap = "{\"node\": {" + + "\"id\": \"projects/34/networks/mesh:meshhh/nodes/4969ef19-24b6-44c0-baf3-86d188ff5967\"" + + "}}"; + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server")::get, + () -> serverBootstrap); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + // Trailers-only + serverBuilder.addService(voidService(Status.INVALID_ARGUMENT)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder<?> channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + assertThrows(StatusRuntimeException.class, () -> + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null)); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "INVALID_ARGUMENT") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "INVALID_ARGUMENT") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + private void register(NameResolverProvider provider) { + assertThat(fakeNameResolverProvider).isNull(); + fakeNameResolverProvider = provider; + NameResolverRegistry.getDefaultRegistry().register(provider); + } + + private static ServerServiceDefinition voidService(Status status) { + return ServerServiceDefinition.builder(TestMethodDescriptors.voidMethod().getServiceName()) + .addMethod(TestMethodDescriptors.voidMethod(), (call, headers) -> { + if (status.isOk()) { + call.sendHeaders(new Metadata()); + call.sendMessage(null); + } + call.close(status, new Metadata()); + return new ServerCall.Listener<Void>() {}; + }) + .build(); + } + + private void assertMetrics( + Attributes preexistingClientAttributes, + Attributes preexistingClientEndAttributes, + Attributes newClientAttributes, + Attributes preexistingServerAttributes, + Attributes newServerAttributes) { + OpenTelemetryAssertions.assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientEndAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes)))); + } + + private static class ProvideFilterMetadataInterceptor implements ClientInterceptor { + private final ImmutableMap<String, Struct> filterMetadata; + + public ProvideFilterMetadataInterceptor(ImmutableMap<String, Struct> filterMetadata) { + this.filterMetadata = filterMetadata; + } + + @Override + public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( + MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { + callOptions.getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER) + .accept(filterMetadata); + return next.newCall(method, callOptions); + } + } +} diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java new file mode 100644 index 000000000..20665e502 --- /dev/null +++ b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import static com.google.common.truth.Truth.assertThat; +import static io.opentelemetry.api.common.AttributeKey.stringKey; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.Metadata; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MetadataExchanger}. */ +@RunWith(JUnit4.class) +public final class MetadataExchangerTest { + @Test + public void getMeshId_findsMeshId() { + assertThat(MetadataExchanger.getMeshId(() -> + "{\"node\":{\"id\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) + .isEqualTo("mine"); + assertThat(MetadataExchanger.getMeshId(() -> + "{\"node\":{\"id\":\"projects/1234567890/networks/mesh:mine/nodes/uu-id\", " + + "\"unknown\": \"\"}, \"unknown\": \"\"}")) + .isEqualTo("mine"); + } + + @Test + public void getMeshId_returnsNullOnBadMeshId() { + assertThat(MetadataExchanger.getMeshId( + () -> "[\"node\"]")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":[\"id\"]}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":[\"projects/12/networks/mesh:mine/nodes/uu-id\"]}}")) + .isNull(); + + assertThat(MetadataExchanger.getMeshId( + () -> "{\"NODE\":{\"id\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"ID\":\"projects/12/networks/mesh:mine/nodes/uu-id\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":\"projects/12/networks/mesh:mine\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":\"PROJECTS/12/networks/mesh:mine/nodes/uu-id\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":\"projects/12/NETWORKS/mesh:mine/nodes/uu-id\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":\"projects/12/networks/MESH:mine/nodes/uu-id\"}}")) + .isNull(); + assertThat(MetadataExchanger.getMeshId( + () -> "{\"node\":{\"id\":\"projects/12/networks/mesh:mine/NODES/uu-id\"}}")) + .isNull(); + } + + @Test + public void enablePluginForChannel_matches() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + assertThat(exchanger.enablePluginForChannel("xds:///testing")).isTrue(); + assertThat(exchanger.enablePluginForChannel("xds:/testing")).isTrue(); + assertThat(exchanger.enablePluginForChannel( + "xds://traffic-director-global.xds.googleapis.com/testing:123")).isTrue(); + } + + @Test + public void enablePluginForChannel_doesNotMatch() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + assertThat(exchanger.enablePluginForChannel("dns:///localhost")).isFalse(); + assertThat(exchanger.enablePluginForChannel("xds:///[]")).isFalse(); + assertThat(exchanger.enablePluginForChannel("xds://my-xds-server/testing")).isFalse(); + } + + @Test + public void addLabels_receivedWrongType() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), + BaseEncoding.base64().encode(Struct.newBuilder() + .putFields("type", Value.newBuilder().setNumberValue(1).build()) + .build() + .toByteArray())); + AttributesBuilder builder = Attributes.builder(); + exchanger.newServerStreamPlugin(metadata).addLabels(builder); + + assertThat(builder.build()).isEqualTo(Attributes.builder() + .put(stringKey("csm.mesh_id"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .build()); + } + + @Test + public void addLabelsFromExchange_unknownGcpType() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null, () -> ""); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), + BaseEncoding.base64().encode(Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_surprise").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .build() + .toByteArray())); + AttributesBuilder builder = Attributes.builder(); + exchanger.newServerStreamPlugin(metadata).addLabels(builder); + + assertThat(builder.build()).isEqualTo(Attributes.builder() + .put(stringKey("csm.mesh_id"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "gcp_surprise") + .put(stringKey("csm.remote_workload_canonical_service"), "myservice1") + .build()); + } + + @Test + public void addMetadata_k8s() throws Exception { + MetadataExchanger exchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "mynamespace1") + .put(stringKey("k8s.cluster.name"), "mycluster1") + .put(stringKey("cloud.availability_zone"), "myzone1") + .put(stringKey("cloud.account.id"), "0001") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "myservice1", + "CSM_WORKLOAD_NAME", "myworkload1")::get, + () -> ""); + Metadata metadata = new Metadata(); + exchanger.newClientCallPlugin().addMetadata(metadata); + + Struct peer = Struct.parseFrom(BaseEncoding.base64().decode(metadata.get( + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER)))); + assertThat(peer).isEqualTo( + Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_kubernetes_engine").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .putFields("workload_name", Value.newBuilder().setStringValue("myworkload1").build()) + .putFields("namespace_name", Value.newBuilder().setStringValue("mynamespace1").build()) + .putFields("cluster_name", Value.newBuilder().setStringValue("mycluster1").build()) + .putFields("location", Value.newBuilder().setStringValue("myzone1").build()) + .putFields("project_id", Value.newBuilder().setStringValue("0001").build()) + .build()); + } + + @Test + public void addMetadata_gce() throws Exception { + MetadataExchanger exchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "myzone1") + .put(stringKey("cloud.account.id"), "0001") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "myservice1", + "CSM_WORKLOAD_NAME", "myworkload1")::get, + () -> ""); + Metadata metadata = new Metadata(); + exchanger.newClientCallPlugin().addMetadata(metadata); + + Struct peer = Struct.parseFrom(BaseEncoding.base64().decode(metadata.get( + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER)))); + assertThat(peer).isEqualTo( + Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_compute_engine").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .putFields("workload_name", Value.newBuilder().setStringValue("myworkload1").build()) + .putFields("location", Value.newBuilder().setStringValue("myzone1").build()) + .putFields("project_id", Value.newBuilder().setStringValue("0001").build()) + .build()); + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 1d2759c1b..dd323cecd 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -18,6 +18,7 @@ androidx-test-ext-junit = "androidx.test.ext:junit:1.1.5" androidx-test-rules = "androidx.test:rules:1.5.0" animalsniffer = "org.codehaus.mojo:animal-sniffer:1.23" animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.23" +assertj-core = "org.assertj:assertj-core:3.24.2" auto-value = "com.google.auto.value:auto-value:1.10.4" auto-value-annotations = "com.google.auto.value:auto-value-annotations:1.10.4" checkstyle = "com.puppycrawl.tools:checkstyle:10.12.5" @@ -44,11 +45,11 @@ j2objc-annotations = " com.google.j2objc:j2objc-annotations:2.8" jakarta-servlet-api = "jakarta.servlet:jakarta.servlet-api:5.0.0" javax-annotation = "org.apache.tomcat:annotations-api:6.0.53" javax-servlet-api = "javax.servlet:javax.servlet-api:4.0.1" -jetty-client = "org.eclipse.jetty:jetty-client:10.0.7" -jetty-http2-server = "org.eclipse.jetty.http2:http2-server:11.0.7" -jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.7" -jetty-servlet = "org.eclipse.jetty:jetty-servlet:11.0.7" -jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.7" +jetty-client = "org.eclipse.jetty:jetty-client:10.0.20" +jetty-http2-server = "org.eclipse.jetty.http2:http2-server:11.0.20" +jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.20" +jetty-servlet = "org.eclipse.jetty:jetty-servlet:11.0.20" +jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.20" jsr305 = "com.google.code.findbugs:jsr305:3.0.2" junit = "junit:junit:4.13.2" lincheck = "org.jetbrains.kotlinx:lincheck:2.14.1" @@ -76,6 +77,8 @@ opencensus-exporter-trace-stackdriver = { module = "io.opencensus:opencensus-exp opencensus-impl = { module = "io.opencensus:opencensus-impl", version.ref = "opencensus" } opencensus-proto = "io.opencensus:opencensus-proto:0.2.0" opentelemetry-api = "io.opentelemetry:opentelemetry-api:1.36.0" +opentelemetry-gcp-resources = "io.opentelemetry.contrib:opentelemetry-gcp-resources:1.34.0-alpha" +opentelemetry-sdk-extension-autoconfigure = "io.opentelemetry:opentelemetry-sdk-extension-autoconfigure:1.36.0" opentelemetry-sdk-testing = "io.opentelemetry:opentelemetry-sdk-testing:1.36.0" perfmark-api = "io.perfmark:perfmark-api:0.26.0" protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } @@ -86,11 +89,11 @@ re2j = "com.google.re2j:re2j:1.7" robolectric = "org.robolectric:robolectric:4.11.1" signature-android = "net.sf.androidscents.signature:android-api-level-19:4.4.2_r4" signature-java = "org.codehaus.mojo.signature:java18:1.0" -tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.0.14" -tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.56" +tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.1.23" +tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.89" truth = "com.google.truth:truth:1.1.5" -undertow-servlet = "io.undertow:undertow-servlet:2.2.14.Final" -undertow-servlet-jakartaee9 = "io.undertow:undertow-servlet-jakartaee9:2.2.13.Final" +undertow-servlet = "io.undertow:undertow-servlet:2.2.32.Final" +undertow-servlet-jakartaee9 = "io.undertow:undertow-servlet:2.3.13.Final" # Do not update: Pinned to the last version supporting Java 8. # See https://checkstyle.sourceforge.io/releasenotes.html#Release_10.1 diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java index f09171487..ae8ad143d 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -203,21 +203,14 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans } }; } - return new Runnable() { - @Override - @SuppressWarnings("deprecation") - public void run() { - synchronized (InProcessTransport.this) { - Attributes serverTransportAttrs = Attributes.newBuilder() - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) - .build(); - serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - } - } - }; + Attributes serverTransportAttrs = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) + .build(); + serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + return null; } @Override @@ -571,7 +564,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans return; } - clientStream.statsTraceCtx.clientInboundHeaders(); + clientStream.statsTraceCtx.clientInboundHeaders(headers); syncContext.executeLater(() -> clientStreamListener.headersRead(headers)); } syncContext.drain(); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java new file mode 100644 index 000000000..f09f196d7 --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.ManagedChannel; +import io.grpc.ServerInterceptors; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.StreamRecorder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.integration.EmptyProtos.Empty; +import io.grpc.testing.integration.Messages.ResponseParameters; +import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; +import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for ManagedChannelImpl that use a real transport. */ +@RunWith(JUnit4.class) +public final class ManagedChannelImplIntegrationTest { + private static final String SERVER_NAME = ManagedChannelImplIntegrationTest.class.getName(); + @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + @Test + public void idleWhileRpcInTransport_exitsIdleForNewRpc() throws Exception { + FakeClock fakeClock = new FakeClock(); + grpcCleanup.register(InProcessServerBuilder.forName(SERVER_NAME) + .directExecutor() + .addService( + ServerInterceptors.intercept( + new TestServiceImpl(fakeClock.getScheduledExecutorService()), + TestServiceImpl.interceptors())) + .build() + .start()); + ManagedChannel channel = grpcCleanup.register(InProcessChannelBuilder.forName(SERVER_NAME) + .directExecutor() + .build()); + + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); + TestServiceGrpc.TestServiceStub asyncStub = TestServiceGrpc.newStub(channel); + StreamRecorder<StreamingOutputCallResponse> responseObserver = StreamRecorder.create(); + StreamObserver<StreamingOutputCallRequest> requestObserver = + asyncStub.fullDuplexCall(responseObserver); + requestObserver.onNext(StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder() + .setIntervalUs(Integer.MAX_VALUE)) + .build()); + try { + channel.enterIdle(); + assertThat(blockingStub + .withDeadlineAfter(10, TimeUnit.SECONDS) + .emptyCall(Empty.getDefaultInstance())) + .isEqualTo(Empty.getDefaultInstance()); + } finally { + requestObserver.onError(new RuntimeException("cleanup")); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index c60cb4824..da3e20e9f 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -45,14 +45,11 @@ import io.grpc.util.AdvancedTlsX509TrustManager.Verification; import io.grpc.util.CertificateUtils; import java.io.Closeable; import java.io.File; -import java.io.IOException; import java.net.Socket; import java.security.GeneralSecurityException; -import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.security.spec.InvalidKeySpecException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -65,13 +62,13 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class AdvancedTlsTest { - public static final String SERVER_0_KEY_FILE = "server0.key"; - public static final String SERVER_0_PEM_FILE = "server0.pem"; - public static final String CLIENT_0_KEY_FILE = "client.key"; - public static final String CLIENT_0_PEM_FILE = "client.pem"; - public static final String CA_PEM_FILE = "ca.pem"; - public static final String SERVER_BAD_KEY_FILE = "badserver.key"; - public static final String SERVER_BAD_PEM_FILE = "badserver.pem"; + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String CLIENT_0_KEY_FILE = "client.key"; + private static final String CLIENT_0_PEM_FILE = "client.pem"; + private static final String CA_PEM_FILE = "ca.pem"; + private static final String SERVER_BAD_KEY_FILE = "badserver.key"; + private static final String SERVER_BAD_PEM_FILE = "badserver.pem"; private ScheduledExecutorService executor; private Server server; @@ -92,7 +89,7 @@ public class AdvancedTlsTest { @Before public void setUp() - throws NoSuchAlgorithmException, IOException, CertificateException, InvalidKeySpecException { + throws Exception { executor = Executors.newSingleThreadScheduledExecutor(); caCertFile = TestUtils.loadCert(CA_PEM_FILE); serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); @@ -285,11 +282,11 @@ public class AdvancedTlsTest { new SslSocketAndEnginePeerVerifier() { @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - Socket socket) throws CertificateException { } + Socket socket) { } @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - SSLEngine engine) throws CertificateException { } + SSLEngine engine) { } }) .build(); serverTrustManager.updateTrustCredentials(caCert); @@ -310,11 +307,11 @@ public class AdvancedTlsTest { new SslSocketAndEnginePeerVerifier() { @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - Socket socket) throws CertificateException { } + Socket socket) { } @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - SSLEngine engine) throws CertificateException { } + SSLEngine engine) { } }) .build(); clientTrustManager.updateTrustCredentials(caCert); @@ -419,7 +416,7 @@ public class AdvancedTlsTest { } @Test - public void onFileReloadingKeyManagerBadInitialContentTest() throws Exception { + public void onFileReloadingKeyManagerBadInitialContentTest() { AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); // We swap the order of key and certificates to intentionally create an exception. assertThrows(GeneralSecurityException.class, @@ -439,7 +436,7 @@ public class AdvancedTlsTest { } @Test - public void keyManagerAliasesTest() throws Exception { + public void keyManagerAliasesTest() { AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); assertArrayEquals( new String[] {"default"}, km.getClientAliases("", null)); diff --git a/opentelemetry/build.gradle b/opentelemetry/build.gradle index 316d85298..509960e5d 100644 --- a/opentelemetry/build.gradle +++ b/opentelemetry/build.gradle @@ -17,7 +17,7 @@ dependencies { testImplementation testFixtures(project(':grpc-core')), project(':grpc-testing'), libraries.opentelemetry.sdk.testing, - "org.assertj:assertj-core:3.24.2" + libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for assertj annotationProcessor libraries.auto.value diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java index 69457fa43..68989d80e 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java @@ -84,8 +84,8 @@ public final class GrpcOpenTelemetry { this.disableDefault = builder.disableAll; this.resource = createMetricInstruments(meter, enableMetrics, disableDefault); this.optionalLabels = ImmutableList.copyOf(builder.optionalLabels); - this.openTelemetryMetricsModule = - new OpenTelemetryMetricsModule(STOPWATCH_SUPPLIER, resource, optionalLabels); + this.openTelemetryMetricsModule = new OpenTelemetryMetricsModule( + STOPWATCH_SUPPLIER, resource, optionalLabels, builder.plugins); this.sink = new OpenTelemetryMetricSink(meter, enableMetrics, disableDefault, optionalLabels); } @@ -272,6 +272,7 @@ public final class GrpcOpenTelemetry { */ public static class Builder { private OpenTelemetry openTelemetrySdk = OpenTelemetry.noop(); + private final List<OpenTelemetryPlugin> plugins = new ArrayList<>(); private final Collection<String> optionalLabels = new ArrayList<>(); private final Map<String, Boolean> enableMetrics = new HashMap<>(); private boolean disableAll; @@ -288,6 +289,11 @@ public final class GrpcOpenTelemetry { return this; } + Builder plugin(OpenTelemetryPlugin plugin) { + plugins.add(checkNotNull(plugin, "plugin")); + return this; + } + /** * Adds optionalLabelKey to all the metrics that can provide value for the * optionalLabelKey. diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java new file mode 100644 index 000000000..5d5543ddd --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.Internal; + +/** + * Internal accessor for {@link GrpcOpenTelemetry}. + */ +@Internal +public final class InternalGrpcOpenTelemetry { + private InternalGrpcOpenTelemetry() {} + + public static void builderPlugin( + GrpcOpenTelemetry.Builder builder, InternalOpenTelemetryPlugin plugin) { + builder.plugin(plugin); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java new file mode 100644 index 000000000..38275506e --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.Internal; + +/** + * Accessors for making plugins. + */ +@Internal +public interface InternalOpenTelemetryPlugin extends OpenTelemetryPlugin { + @Override + ClientCallPlugin newClientCallPlugin(); + + interface ClientCallPlugin extends OpenTelemetryPlugin.ClientCallPlugin { + @Override + ClientStreamPlugin newClientStreamPlugin(); + } + + interface ClientStreamPlugin extends OpenTelemetryPlugin.ClientStreamPlugin { + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java index c93e5a10a..f631da59d 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java @@ -25,6 +25,7 @@ import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TARGET_KEY; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.grpc.CallOptions; import io.grpc.Channel; @@ -42,7 +43,10 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.opentelemetry.api.common.AttributesBuilder; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLong; @@ -86,12 +90,15 @@ final class OpenTelemetryMetricsModule { private final OpenTelemetryMetricsResource resource; private final Supplier<Stopwatch> stopwatchSupplier; private final boolean localityEnabled; + private final ImmutableList<OpenTelemetryPlugin> plugins; OpenTelemetryMetricsModule(Supplier<Stopwatch> stopwatchSupplier, - OpenTelemetryMetricsResource resource, Collection<String> optionalLabels) { + OpenTelemetryMetricsResource resource, Collection<String> optionalLabels, + List<OpenTelemetryPlugin> plugins) { this.resource = checkNotNull(resource, "resource"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.localityEnabled = optionalLabels.contains(LOCALITY_LABEL_NAME); + this.plugins = ImmutableList.copyOf(plugins); } /** @@ -105,7 +112,14 @@ final class OpenTelemetryMetricsModule { * Returns the client interceptor that facilitates OpenTelemetry metrics reporting. */ ClientInterceptor getClientInterceptor(String target) { - return new MetricsClientInterceptor(target); + ImmutableList.Builder<OpenTelemetryPlugin> pluginBuilder = + ImmutableList.builderWithExpectedSize(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + if (plugin.enablePluginForChannel(target)) { + pluginBuilder.add(plugin); + } + } + return new MetricsClientInterceptor(target, pluginBuilder.build()); } static String recordMethodName(String fullMethodName, boolean isGeneratedMethod) { @@ -144,6 +158,7 @@ final class OpenTelemetryMetricsModule { final StreamInfo info; final String target; final String fullMethodName; + final List<OpenTelemetryPlugin.ClientStreamPlugin> streamPlugins; volatile long outboundWireSize; volatile long inboundWireSize; volatile String locality; @@ -151,16 +166,25 @@ final class OpenTelemetryMetricsModule { Code statusCode; ClientTracer(CallAttemptsTracerFactory attemptsState, OpenTelemetryMetricsModule module, - StreamInfo info, String target, String fullMethodName) { + StreamInfo info, String target, String fullMethodName, + List<OpenTelemetryPlugin.ClientStreamPlugin> streamPlugins) { this.attemptsState = attemptsState; this.module = module; this.info = info; this.target = target; this.fullMethodName = fullMethodName; + this.streamPlugins = streamPlugins; this.stopwatch = module.stopwatchSupplier.get().start(); } @Override + public void inboundHeaders(Metadata headers) { + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.inboundHeaders(headers); + } + } + + @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void outboundWireSize(long bytes) { if (outboundWireSizeUpdater != null) { @@ -188,6 +212,13 @@ final class OpenTelemetryMetricsModule { } @Override + public void inboundTrailers(Metadata trailers) { + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.inboundTrailers(trailers); + } + } + + @Override public void streamClosed(Status status) { stopwatch.stop(); attemptNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); @@ -217,6 +248,9 @@ final class OpenTelemetryMetricsModule { } builder.put(LOCALITY_KEY, savedLocality); } + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.addLabels(builder); + } io.opentelemetry.api.common.Attributes attribute = builder.build(); if (module.resource.clientAttemptDurationCounter() != null ) { @@ -243,6 +277,7 @@ final class OpenTelemetryMetricsModule { @GuardedBy("lock") private boolean callEnded; private final String fullMethodName; + private final List<OpenTelemetryPlugin.ClientCallPlugin> callPlugins; private Status status; private long callLatencyNanos; private final Object lock = new Object(); @@ -253,10 +288,14 @@ final class OpenTelemetryMetricsModule { private boolean finishedCallToBeRecorded; CallAttemptsTracerFactory( - OpenTelemetryMetricsModule module, String target, String fullMethodName) { + OpenTelemetryMetricsModule module, + String target, + String fullMethodName, + List<OpenTelemetryPlugin.ClientCallPlugin> callPlugins) { this.module = checkNotNull(module, "module"); this.target = checkNotNull(target, "target"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); + this.callPlugins = checkNotNull(callPlugins, "callPlugins"); this.attemptStopwatch = module.stopwatchSupplier.get(); this.callStopWatch = module.stopwatchSupplier.get().start(); @@ -295,7 +334,19 @@ final class OpenTelemetryMetricsModule { if (!info.isTransparentRetry()) { attemptsPerCall.incrementAndGet(); } - return new ClientTracer(this, module, info, target, fullMethodName); + return newClientTracer(info); + } + + private ClientTracer newClientTracer(StreamInfo info) { + List<OpenTelemetryPlugin.ClientStreamPlugin> streamPlugins = Collections.emptyList(); + if (!callPlugins.isEmpty()) { + streamPlugins = new ArrayList<>(callPlugins.size()); + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + streamPlugins.add(plugin.newClientStreamPlugin()); + } + streamPlugins = Collections.unmodifiableList(streamPlugins); + } + return new ClientTracer(this, module, info, target, fullMethodName, streamPlugins); } // Called whenever each attempt is ended. @@ -337,8 +388,7 @@ final class OpenTelemetryMetricsModule { void recordFinishedCall() { if (attemptsPerCall.get() == 0) { - ClientTracer tracer = - new ClientTracer(this, module, null, target, fullMethodName); + ClientTracer tracer = newClientTracer(null); tracer.attemptNanos = attemptStopwatch.elapsed(TimeUnit.NANOSECONDS); tracer.statusCode = status.getCode(); tracer.recordFinishedAttempt(); @@ -390,15 +440,18 @@ final class OpenTelemetryMetricsModule { private final OpenTelemetryMetricsModule module; private final String fullMethodName; + private final List<OpenTelemetryPlugin.ServerStreamPlugin> streamPlugins; private volatile boolean isGeneratedMethod; private volatile int streamClosed; private final Stopwatch stopwatch; private volatile long outboundWireSize; private volatile long inboundWireSize; - ServerTracer(OpenTelemetryMetricsModule module, String fullMethodName) { + ServerTracer(OpenTelemetryMetricsModule module, String fullMethodName, + List<OpenTelemetryPlugin.ServerStreamPlugin> streamPlugins) { this.module = checkNotNull(module, "module"); this.fullMethodName = fullMethodName; + this.streamPlugins = checkNotNull(streamPlugins, "streamPlugins"); this.stopwatch = module.stopwatchSupplier.get().start(); } @@ -458,10 +511,13 @@ final class OpenTelemetryMetricsModule { } stopwatch.stop(); long elapsedTimeNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); - io.opentelemetry.api.common.Attributes attributes = - io.opentelemetry.api.common.Attributes.of( - METHOD_KEY, recordMethodName(fullMethodName, isGeneratedMethod), - STATUS_KEY, status.getCode().toString()); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, recordMethodName(fullMethodName, isGeneratedMethod)) + .put(STATUS_KEY, status.getCode().toString()); + for (OpenTelemetryPlugin.ServerStreamPlugin plugin : streamPlugins) { + plugin.addLabels(builder); + } + io.opentelemetry.api.common.Attributes attributes = builder.build(); if (module.resource.serverCallDurationCounter() != null) { module.resource.serverCallDurationCounter() @@ -482,32 +538,63 @@ final class OpenTelemetryMetricsModule { final class ServerTracerFactory extends ServerStreamTracer.Factory { @Override public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { - return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName); + final List<OpenTelemetryPlugin.ServerStreamPlugin> streamPlugins; + if (plugins.isEmpty()) { + streamPlugins = Collections.emptyList(); + } else { + List<OpenTelemetryPlugin.ServerStreamPlugin> streamPluginsMutable = + new ArrayList<>(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + streamPluginsMutable.add(plugin.newServerStreamPlugin(headers)); + } + streamPlugins = Collections.unmodifiableList(streamPluginsMutable); + } + return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, streamPlugins); } } @VisibleForTesting final class MetricsClientInterceptor implements ClientInterceptor { private final String target; + private final ImmutableList<OpenTelemetryPlugin> plugins; - MetricsClientInterceptor(String target) { + MetricsClientInterceptor(String target, ImmutableList<OpenTelemetryPlugin> plugins) { this.target = checkNotNull(target, "target"); + this.plugins = checkNotNull(plugins, "plugins"); } @Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { + final List<OpenTelemetryPlugin.ClientCallPlugin> callPlugins; + if (plugins.isEmpty()) { + callPlugins = Collections.emptyList(); + } else { + List<OpenTelemetryPlugin.ClientCallPlugin> callPluginsMutable = + new ArrayList<>(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + callPluginsMutable.add(plugin.newClientCallPlugin()); + } + callPlugins = Collections.unmodifiableList(callPluginsMutable); + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + callOptions = plugin.filterCallOptions(callOptions); + } + } // Only record method name as an attribute if isSampledToLocalTracing is set to true, // which is true for all generated methods. Otherwise, programatically // created methods result in high cardinality metrics. final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( OpenTelemetryMetricsModule.this, target, - recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing())); + recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing()), + callPlugins); ClientCall<ReqT, RespT> call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall<ReqT, RespT>(call) { @Override public void start(Listener<RespT> responseListener, Metadata headers) { + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + plugin.addMetadata(headers); + } delegate().start( new SimpleForwardingClientCallListener<RespT>(responseListener) { @Override diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java new file mode 100644 index 000000000..3705b4b65 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.CallOptions; +import io.grpc.Metadata; +import io.opentelemetry.api.common.AttributesBuilder; + +/** + * Injects behavior into {@link GrpcOpenTelemetry}. + */ +interface OpenTelemetryPlugin { + /** + * Limited ability to disable the plugin based on the target. This only has an effect for + * per-call metrics. + * + * <p>Ideally this method wouldn't exist and it'd be handled by wrapping GrpcOpenTelemetry and + * conditionally delegating to it. But this is needed by CSM until ChannelBuilders have a + * consistent target over their life; currently specifying nameResolverFactory can change the + * target's scheme. + */ + default boolean enablePluginForChannel(String target) { + return true; + } + + ClientCallPlugin newClientCallPlugin(); + + ServerStreamPlugin newServerStreamPlugin(Metadata inboundMetadata); + + interface ClientCallPlugin { + ClientStreamPlugin newClientStreamPlugin(); + + default void addMetadata(Metadata toMetadata) {} + + default CallOptions filterCallOptions(CallOptions options) { + return options; + } + } + + interface ClientStreamPlugin { + default void inboundHeaders(Metadata headers) {} + + default void inboundTrailers(Metadata trailers) {} + + default void addLabels(AttributesBuilder to) {} + } + + interface ServerStreamPlugin { + default void addLabels(AttributesBuilder to) {} + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java index 8193f1191..17a2b26b1 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java @@ -22,6 +22,7 @@ import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.METHOD_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.STATUS_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TARGET_KEY; import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.assertThat; +import static java.util.Collections.emptyList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -218,7 +219,7 @@ public class OpenTelemetryMetricsModuleTest { enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName()); + new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); Metadata headers = new Metadata(); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); @@ -358,7 +359,7 @@ public class OpenTelemetryMetricsModuleTest { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, - method.getFullMethodName()); + method.getFullMethodName(), emptyList()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -783,7 +784,7 @@ public class OpenTelemetryMetricsModuleTest { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, - method.getFullMethodName()); + method.getFullMethodName(), emptyList()); fakeClock.forwardTime(3000, MILLISECONDS); Status status = Status.DEADLINE_EXCEEDED.withDescription("5 seconds"); callAttemptsTracerFactory.callEnded(status); @@ -885,9 +886,9 @@ public class OpenTelemetryMetricsModuleTest { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality")); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName()); + new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -953,9 +954,9 @@ public class OpenTelemetryMetricsModuleTest { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality")); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, target, method.getFullMethodName()); + new CallAttemptsTracerFactory(module, target, method.getFullMethodName(), emptyList()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1128,7 +1129,7 @@ public class OpenTelemetryMetricsModuleTest { private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( OpenTelemetryMetricsResource resource) { return new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList()); + fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList()); } static class CallInfo<ReqT, RespT> extends ServerCallInfo<ReqT, RespT> { diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index c54c93511..0dcffadeb 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -109,6 +109,7 @@ final class CachingRlsLbClient { // LRU cache based on access order (BACKOFF and actual data will be here) @GuardedBy("lock") private final RlsAsyncLruCache linkedHashLruCache; + private final Future<?> periodicCleaner; // any RPC on the fly will cached in this map @GuardedBy("lock") private final Map<RouteLookupRequest, PendingCacheEntry> pendingCallCache = new HashMap<>(); @@ -177,10 +178,10 @@ final class CachingRlsLbClient { new RlsAsyncLruCache( rlsConfig.cacheSizeBytes(), new AutoCleaningEvictionListener(builder.evictionListener), - scheduledExecutorService, ticker, - lock, helper); + periodicCleaner = + scheduledExecutorService.scheduleAtFixedRate(this::periodicClean, 1, 1, TimeUnit.MINUTES); logger = helper.getChannelLogger(); String serverHost = null; try { @@ -267,6 +268,12 @@ final class CachingRlsLbClient { serverName, status.getCode(), status.getDescription())); } + private void periodicClean() { + synchronized (lock) { + linkedHashLruCache.cleanupExpiredEntries(); + } + } + /** Populates async cache entry for new request. */ @GuardedBy("lock") private CachedRouteLookupResponse asyncRlsCall( @@ -349,6 +356,7 @@ final class CachingRlsLbClient { void close() { logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient closed"); synchronized (lock) { + periodicCleaner.cancel(false); // all childPolicyWrapper will be returned via AutoCleaningEvictionListener linkedHashLruCache.close(); // TODO(creamsoup) maybe cancel all pending requests @@ -887,15 +895,8 @@ final class CachingRlsLbClient { RlsAsyncLruCache(long maxEstimatedSizeBytes, @Nullable EvictionListener<RouteLookupRequest, CacheEntry> evictionListener, - ScheduledExecutorService ses, Ticker ticker, Object lock, RlsLbHelper helper) { - super( - maxEstimatedSizeBytes, - evictionListener, - 1, - TimeUnit.MINUTES, - ses, - ticker, - lock); + Ticker ticker, RlsLbHelper helper) { + super(maxEstimatedSizeBytes, evictionListener, ticker); this.helper = checkNotNull(helper, "helper"); } diff --git a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java index c1cbb28f2..ba0575efa 100644 --- a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java +++ b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java @@ -29,46 +29,30 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; /** * A LinkedHashLruCache implements least recently used caching where it supports access order lru * cache eviction while allowing entry level expiration time. When the cache reaches max capacity, * LruCache try to remove up to one already expired entries. If it doesn't find any expired entries, - * it will remove based on access order of entry. On top of this, LruCache also proactively removes - * expired entries based on configured time interval. + * it will remove based on access order of entry. To proactively clean up expired entries, call + * {@link #cleanupExpiredEntries()} (e.g., via a recurring timer). */ -@ThreadSafe abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { - private final Object lock; - - @GuardedBy("lock") private final LinkedHashMap<K, SizedValue> delegate; - private final PeriodicCleaner periodicCleaner; private final Ticker ticker; private final EvictionListener<K, SizedValue> evictionListener; - private final AtomicLong estimatedSizeBytes = new AtomicLong(); + private long estimatedSizeBytes; private long estimatedMaxSizeBytes; LinkedHashLruCache( final long estimatedMaxSizeBytes, @Nullable final EvictionListener<K, V> evictionListener, - int cleaningInterval, - TimeUnit cleaningIntervalUnit, - ScheduledExecutorService ses, - final Ticker ticker, - Object lock) { + final Ticker ticker) { checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive"); this.estimatedMaxSizeBytes = estimatedMaxSizeBytes; - this.lock = checkNotNull(lock, "lock"); this.evictionListener = new SizeHandlingEvictionListener(evictionListener); this.ticker = checkNotNull(ticker, "ticker"); delegate = new LinkedHashMap<K, SizedValue>( @@ -78,7 +62,7 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { /* accessOrder= */ true) { @Override protected boolean removeEldestEntry(Map.Entry<K, SizedValue> eldest) { - if (estimatedSizeBytes.get() <= LinkedHashLruCache.this.estimatedMaxSizeBytes) { + if (estimatedSizeBytes <= LinkedHashLruCache.this.estimatedMaxSizeBytes) { return false; } @@ -94,7 +78,6 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { return false; } }; - periodicCleaner = new PeriodicCleaner(ses, cleaningInterval, cleaningIntervalUnit).start(); } /** @@ -124,16 +107,14 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { /** Updates size for given key if entry exists. It is useful if the cache value is mutated. */ public void updateEntrySize(K key) { - synchronized (lock) { - SizedValue entry = readInternal(key); - if (entry == null) { - return; - } - int prevSize = entry.size; - int newSize = estimateSizeOf(key, entry.value); - entry.size = newSize; - estimatedSizeBytes.addAndGet(newSize - prevSize); + SizedValue entry = readInternal(key); + if (entry == null) { + return; } + int prevSize = entry.size; + int newSize = estimateSizeOf(key, entry.value); + entry.size = newSize; + estimatedSizeBytes += newSize - prevSize; } /** @@ -141,7 +122,7 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { * #estimateSizeOf(java.lang.Object, java.lang.Object)}. */ public long estimatedSizeBytes() { - return estimatedSizeBytes.get(); + return estimatedSizeBytes; } @Override @@ -151,12 +132,10 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { checkNotNull(value, "value"); SizedValue existing; int size = estimateSizeOf(key, value); - synchronized (lock) { - estimatedSizeBytes.addAndGet(size); - existing = delegate.put(key, new SizedValue(size, value)); - if (existing != null) { - evictionListener.onEviction(key, existing, EvictionType.REPLACED); - } + estimatedSizeBytes += size; + existing = delegate.put(key, new SizedValue(size, value)); + if (existing != null) { + evictionListener.onEviction(key, existing, EvictionType.REPLACED); } return existing == null ? null : existing.value; } @@ -176,13 +155,11 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { @CheckReturnValue private SizedValue readInternal(K key) { checkNotNull(key, "key"); - synchronized (lock) { - SizedValue existing = delegate.get(key); - if (existing != null && isExpired(key, existing.value, ticker.read())) { - return null; - } - return existing; + SizedValue existing = delegate.get(key); + if (existing != null && isExpired(key, existing.value, ticker.read())) { + return null; } + return existing; } @Override @@ -195,26 +172,22 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { private V invalidate(K key, EvictionType cause) { checkNotNull(key, "key"); checkNotNull(cause, "cause"); - synchronized (lock) { - SizedValue existing = delegate.remove(key); - if (existing != null) { - evictionListener.onEviction(key, existing, cause); - } - return existing == null ? null : existing.value; + SizedValue existing = delegate.remove(key); + if (existing != null) { + evictionListener.onEviction(key, existing, cause); } + return existing == null ? null : existing.value; } @Override public final void invalidateAll() { - synchronized (lock) { - Iterator<Map.Entry<K, SizedValue>> iterator = delegate.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry<K, SizedValue> entry = iterator.next(); - if (entry.getValue() != null) { - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); - } - iterator.remove(); + Iterator<Map.Entry<K, SizedValue>> iterator = delegate.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry<K, SizedValue> entry = iterator.next(); + if (entry.getValue() != null) { + evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); } + iterator.remove(); } } @@ -227,13 +200,11 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { /** Returns shallow copied values in the cache. */ public final List<V> values() { - synchronized (lock) { - List<V> list = new ArrayList<>(delegate.size()); - for (SizedValue value : delegate.values()) { - list.add(value.value); - } - return Collections.unmodifiableList(list); + List<V> list = new ArrayList<>(delegate.size()); + for (SizedValue value : delegate.values()) { + list.add(value.value); } + return Collections.unmodifiableList(list); } /** @@ -243,27 +214,25 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { */ protected final boolean fitToLimit() { boolean removedAnyUnexpired = false; - synchronized (lock) { - if (estimatedSizeBytes.get() <= estimatedMaxSizeBytes) { - // new size is larger no need to do cleanup - return false; - } - // cleanup expired entries - long now = ticker.read(); - cleanupExpiredEntries(now); - - // cleanup eldest entry until new size limit - Iterator<Map.Entry<K, SizedValue>> lruIter = delegate.entrySet().iterator(); - while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes.get()) { - Map.Entry<K, SizedValue> entry = lruIter.next(); - if (!shouldInvalidateEldestEntry(entry.getKey(), entry.getValue().value, now)) { - break; // Violates some constraint like minimum age so stop our cleanup - } - lruIter.remove(); - // eviction listener will update the estimatedSizeBytes - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); - removedAnyUnexpired = true; + if (estimatedSizeBytes <= estimatedMaxSizeBytes) { + // new size is larger no need to do cleanup + return false; + } + // cleanup expired entries + long now = ticker.read(); + cleanupExpiredEntries(now); + + // cleanup eldest entry until new size limit + Iterator<Map.Entry<K, SizedValue>> lruIter = delegate.entrySet().iterator(); + while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes) { + Map.Entry<K, SizedValue> entry = lruIter.next(); + if (!shouldInvalidateEldestEntry(entry.getKey(), entry.getValue().value, now)) { + break; // Violates some constraint like minimum age so stop our cleanup } + lruIter.remove(); + // eviction listener will update the estimatedSizeBytes + evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); + removedAnyUnexpired = true; } return removedAnyUnexpired; } @@ -273,18 +242,19 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { * removing expired entries and removing oldest entries by LRU order. */ public final void resize(long newSizeBytes) { - synchronized (lock) { - this.estimatedMaxSizeBytes = newSizeBytes; - fitToLimit(); - } + this.estimatedMaxSizeBytes = newSizeBytes; + fitToLimit(); } @Override @CheckReturnValue public final int estimatedSize() { - synchronized (lock) { - return delegate.size(); - } + return delegate.size(); + } + + /** Returns {@code true} if any entries were removed. */ + public final boolean cleanupExpiredEntries() { + return cleanupExpiredEntries(ticker.read()); } private boolean cleanupExpiredEntries(long now) { @@ -295,16 +265,14 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { private boolean cleanupExpiredEntries(int maxExpiredEntries, long now) { checkArgument(maxExpiredEntries > 0, "maxExpiredEntries must be positive"); boolean removedAny = false; - synchronized (lock) { - Iterator<Map.Entry<K, SizedValue>> lruIter = delegate.entrySet().iterator(); - while (lruIter.hasNext() && maxExpiredEntries > 0) { - Map.Entry<K, SizedValue> entry = lruIter.next(); - if (isExpired(entry.getKey(), entry.getValue().value, now)) { - lruIter.remove(); - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); - removedAny = true; - maxExpiredEntries--; - } + Iterator<Map.Entry<K, SizedValue>> lruIter = delegate.entrySet().iterator(); + while (lruIter.hasNext() && maxExpiredEntries > 0) { + Map.Entry<K, SizedValue> entry = lruIter.next(); + if (isExpired(entry.getKey(), entry.getValue().value, now)) { + lruIter.remove(); + evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); + removedAny = true; + maxExpiredEntries--; } } return removedAny; @@ -312,48 +280,7 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { @Override public final void close() { - synchronized (lock) { - periodicCleaner.stop(); - invalidateAll(); - } - } - - /** Periodically cleans up the AsyncRequestCache. */ - private final class PeriodicCleaner { - - private final ScheduledExecutorService ses; - private final int interval; - private final TimeUnit intervalUnit; - private ScheduledFuture<?> scheduledFuture; - - PeriodicCleaner(ScheduledExecutorService ses, int interval, TimeUnit intervalUnit) { - this.ses = checkNotNull(ses, "ses"); - checkState(interval > 0, "interval must be positive"); - this.interval = interval; - this.intervalUnit = checkNotNull(intervalUnit, "intervalUnit"); - } - - PeriodicCleaner start() { - checkState(scheduledFuture == null, "cleaning task can be started only once"); - this.scheduledFuture = - ses.scheduleAtFixedRate(new CleaningTask(), interval, interval, intervalUnit); - return this; - } - - void stop() { - if (scheduledFuture != null) { - scheduledFuture.cancel(false); - scheduledFuture = null; - } - } - - private class CleaningTask implements Runnable { - - @Override - public void run() { - cleanupExpiredEntries(ticker.read()); - } - } + invalidateAll(); } /** A {@link EvictionListener} keeps track of size. */ @@ -367,7 +294,7 @@ abstract class LinkedHashLruCache<K, V> implements LruCache<K, V> { @Override public void onEviction(K key, SizedValue value, EvictionType cause) { - estimatedSizeBytes.addAndGet(-1L * value.size); + estimatedSizeBytes -= value.size; if (delegate != null) { delegate.onEviction(key, value.value, cause); } diff --git a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java index a31f58f53..f38b28d84 100644 --- a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java +++ b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java @@ -56,14 +56,10 @@ public class LinkedHashLruCacheTest { this.cache = new LinkedHashLruCache<Integer, Entry>( MAX_SIZE, evictionListener, - 10, - TimeUnit.NANOSECONDS, - fakeClock.getScheduledExecutorService(), - fakeClock.getTicker(), - new Object()) { + fakeClock.getTicker()) { @Override protected boolean isExpired(Integer key, Entry value, long nowNanos) { - return value.expireTime <= nowNanos; + return value.expireTime - nowNanos <= 0; } @Override @@ -107,9 +103,11 @@ public class LinkedHashLruCacheTest { cache.cache(1, survivor); fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); + cache.cleanupExpiredEntries(); verify(evictionListener).onEviction(0, toBeEvicted, EvictionType.EXPIRED); fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); + cache.cleanupExpiredEntries(); verify(evictionListener).onEviction(1, survivor, EvictionType.EXPIRED); } @@ -160,6 +158,7 @@ public class LinkedHashLruCacheTest { assertThat(cache.estimatedSize()).isEqualTo(MAX_SIZE); fakeClock.forwardTime(1, TimeUnit.MINUTES); + cache.cleanupExpiredEntries(); assertThat(cache.read(MAX_SIZE)).isNull(); assertThat(cache.estimatedSize()).isEqualTo(MAX_SIZE - 1); verify(evictionListener).onEviction(eq(MAX_SIZE), any(Entry.class), eq(EvictionType.EXPIRED)); diff --git a/servlet/build.gradle b/servlet/build.gradle index fe5914f51..b6f8f6f0c 100644 --- a/servlet/build.gradle +++ b/servlet/build.gradle @@ -37,15 +37,14 @@ dependencies { compileOnly libraries.javax.servlet.api, libraries.javax.annotation // java 9, 10 needs it - implementation project(':grpc-util'), - project(':grpc-core'), + implementation project(':grpc-core'), libraries.guava testImplementation libraries.javax.servlet.api threadingTestImplementation project(':grpc-servlet'), libraries.truth, - libraries.javax.servlet.api, + libraries.javax.servlet.api, libraries.lincheck itImplementation project(':grpc-servlet'), diff --git a/servlet/jakarta/build.gradle b/servlet/jakarta/build.gradle index f548805bd..71de639da 100644 --- a/servlet/jakarta/build.gradle +++ b/servlet/jakarta/build.gradle @@ -7,16 +7,7 @@ description = "gRPC: Jakarta Servlet" // Set up classpaths and source directories for different servlet tests sourceSets { - undertowTest { - java { - include '**/Undertow*.java' - } - } - tomcatTest { - java { - include '**/Tomcat*.java' - } - } + // Only run these tests if java 11+ is being used if (JavaVersion.current().isJava11Compatible()) { jettyTest { @@ -24,6 +15,16 @@ sourceSets { include '**/Jetty*.java' } } + tomcatTest { + java { + include '**/Tomcat*.java' + } + } + undertowTest { + java { + include '**/Undertow*.java' + } + } } } @@ -56,12 +57,11 @@ def migrate(String name, String inputDir, SourceSet sourceSet) { migrate('main', '../src/main/java', sourceSets.main) -// Build the set of sourceSets and classpaths to modify, since Jetty 11 requires Java 11 -// and must be skipped -migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest) -migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest) +// Only build sourceSets and classpaths for tests if using Java 11 if (JavaVersion.current().isJava11Compatible()) { migrate('jettyTest', '../src/jettyTest/java', sourceSets.jettyTest) + migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest) + migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest) } // Disable checkstyle for this project, since it consists only of generated code @@ -104,41 +104,36 @@ dependencies { // Set up individual classpaths for each test, to avoid any mismatch, // and ensure they are only used when supported by the current jvm -def undertowTest = tasks.register('undertowTest', Test) { - classpath = sourceSets.undertowTest.runtimeClasspath - testClassesDirs = sourceSets.undertowTest.output.classesDirs -} -def tomcat10Test = tasks.register('tomcat10Test', Test) { - classpath = sourceSets.tomcatTest.runtimeClasspath - testClassesDirs = sourceSets.tomcatTest.output.classesDirs - - // Provide a temporary directory for tomcat to be deleted after test finishes - def tomcatTempDir = "$buildDir/tomcat_catalina_base" - systemProperty 'catalina.base', tomcatTempDir - doLast { - file(tomcatTempDir).deleteDir() - } - - // tomcat-embed-core 10 presently performs illegal reflective access on - // java.io.ObjectStreamClass$Caches.localDescs and sun.rmi.transport.Target.ccl, - // see https://lists.apache.org/thread/s0xr7tk2kfkkxfjps9n7dhh4cypfdhyy - if (JavaVersion.current().isJava9Compatible()) { - jvmArgs += ['--add-opens=java.base/java.io=ALL-UNNAMED', '--add-opens=java.rmi/sun.rmi.transport=ALL-UNNAMED'] - } -} - -tasks.named("check").configure { - dependsOn undertowTest, tomcat10Test -} - -// Only run these tests if java 11+ is being used if (JavaVersion.current().isJava11Compatible()) { def jetty11Test = tasks.register('jetty11Test', Test) { classpath = sourceSets.jettyTest.runtimeClasspath testClassesDirs = sourceSets.jettyTest.output.classesDirs } + def tomcat10Test = tasks.register('tomcat10Test', Test) { + classpath = sourceSets.tomcatTest.runtimeClasspath + testClassesDirs = sourceSets.tomcatTest.output.classesDirs + + // Provide a temporary directory for tomcat to be deleted after test finishes + def tomcatTempDir = "$buildDir/tomcat_catalina_base" + systemProperty 'catalina.base', tomcatTempDir + doLast { + file(tomcatTempDir).deleteDir() + } + } + tasks.named('compileTomcatTestJava') { JavaCompile task -> + task.options.release.set 11 + } + + def undertowTest = tasks.register('undertowTest', Test) { + classpath = sourceSets.undertowTest.runtimeClasspath + testClassesDirs = sourceSets.undertowTest.output.classesDirs + } + tasks.named('compileUndertowTestJava') { JavaCompile task -> + task.options.release.set 11 + } + tasks.named("check").configure { - dependsOn jetty11Test + dependsOn jetty11Test, tomcat10Test, undertowTest } } diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java index f21754fb6..d1fc90752 100644 --- a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java @@ -252,4 +252,19 @@ public class JettyTransportTest extends AbstractTransportTest { @Ignore("regression since bumping grpc v1.46 to v1.53") @Test public void messageProducerOnlyProducesRequestedMessages() {} + + @Override + @Ignore("https://github.com/jetty/jetty.project/issues/11822") + @Test + public void clientChecksInboundMetadataSize_header() {} + + @Override + @Ignore("https://github.com/jetty/jetty.project/issues/11822") + @Test + public void clientChecksInboundMetadataSize_trailer() {} + + @Override + @Ignore("Not yet investigated, but has been seen for multiple servlet containers") + @Test + public void clientShutdownBeforeStartRunnable() {} } diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java index 262036883..21a1d3db7 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java @@ -274,4 +274,9 @@ public class TomcatTransportTest extends AbstractTransportTest { @Ignore("regression since bumping grpc v1.46 to v1.53") @Test public void messageProducerOnlyProducesRequestedMessages() {} + + @Override + @Ignore("Not yet investigated, but has been seen for multiple servlet containers") + @Test + public void clientShutdownBeforeStartRunnable() {} } diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java index e14c11985..76dfb0c0f 100644 --- a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java @@ -308,4 +308,9 @@ public class UndertowTransportTest extends AbstractTransportTest { @Ignore("regression since bumping grpc v1.46 to v1.53") @Test public void messageProducerOnlyProducesRequestedMessages() {} + + @Override + @Ignore("Not yet investigated, but has been seen for multiple servlet containers") + @Test + public void clientShutdownBeforeStartRunnable() {} } diff --git a/settings.gradle b/settings.gradle index ae6e395e7..6359db91c 100644 --- a/settings.gradle +++ b/settings.gradle @@ -68,6 +68,7 @@ include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" include ":grpc-authz" +include ":grpc-gcp-csm-observability" include ":grpc-gcp-observability" include ":grpc-gcp-observability:interop" include ":grpc-istio-interop-testing" @@ -102,6 +103,7 @@ project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File project(':grpc-authz').projectDir = "$rootDir/authz" as File +project(':grpc-gcp-csm-observability').projectDir = "$rootDir/gcp-csm-observability" as File project(':grpc-gcp-observability').projectDir = "$rootDir/gcp-observability" as File project(':grpc-gcp-observability:interop').projectDir = "$rootDir/gcp-observability/interop" as File project(':grpc-istio-interop-testing').projectDir = "$rootDir/istio-interop-testing" as File diff --git a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java index 52bbc8efb..4664dbcc4 100644 --- a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java +++ b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java @@ -52,7 +52,7 @@ public final class FakeNameResolverProvider extends NameResolverProvider { @Override protected int priority() { - return 5; // Default + return 10; // High priority } @Override diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 1530834d6..4292107db 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -18,7 +18,6 @@ package io.grpc.util; import static com.google.common.base.Preconditions.checkNotNull; -import io.grpc.ExperimentalApi; import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -26,7 +25,6 @@ import java.net.Socket; import java.security.GeneralSecurityException; import java.security.Principal; import java.security.PrivateKey; -import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.concurrent.ScheduledExecutorService; @@ -39,20 +37,15 @@ import javax.net.ssl.X509ExtendedKeyManager; /** * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure - * advanced TLS features, such as private key and certificate chain reloading, etc. + * advanced TLS features, such as private key and certificate chain reloading. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); - - // The credential information sent to peers to prove our identity. + // Minimum allowed period for refreshing files with credential information. + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ; + // The credential information to be sent to peers to prove our identity. private volatile KeyInfo keyInfo; - /** - * Constructs an AdvancedTlsX509KeyManager. - */ - public AdvancedTlsX509KeyManager() throws CertificateException { } - @Override public PrivateKey getPrivateKey(String alias) { if (alias.equals("default")) { @@ -107,14 +100,17 @@ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { * @param certs the certificate chain that is going to be used */ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) { - // TODO(ZhenLian): explore possibilities to do a crypto check here. this.keyInfo = new KeyInfo(checkNotNull(key, "key"), checkNotNull(certs, "certs")); } /** * Schedules a {@code ScheduledExecutorService} to read private key and certificate chains from * the local file paths periodically, and update the cached identity credentials if they are both - * updated. + * updated. You must close the returned Closeable before calling this method again or other update + * methods ({@link AdvancedTlsX509KeyManager#updateIdentityCredentials}, {@link + * AdvancedTlsX509KeyManager#updateIdentityCredentialsFromFile(File, File)}). + * Before scheduling the task, the method synchronously executes {@code readAndUpdate} once. The + * minimum refresh period of 1 minute is enforced. * * @param keyFile the file on disk holding the private key * @param certFile the file on disk holding the certificate chain @@ -131,14 +127,17 @@ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { throw new GeneralSecurityException( "Files were unmodified before their initial update. Probably a bug."); } + if (checkNotNull(unit, "unit").toMinutes(period) < MINIMUM_REFRESH_PERIOD_IN_MINUTES) { + log.log(Level.FINE, + "Provided refresh period of {0} {1} is too small. Default value of {2} minute(s) " + + "will be used.", new Object[] {period, unit.name(), MINIMUM_REFRESH_PERIOD_IN_MINUTES}); + period = MINIMUM_REFRESH_PERIOD_IN_MINUTES; + unit = TimeUnit.MINUTES; + } final ScheduledFuture<?> future = - executor.scheduleWithFixedDelay( + checkNotNull(executor, "executor").scheduleWithFixedDelay( new LoadFilePathExecution(keyFile, certFile), period, period, unit); - return new Closeable() { - @Override public void close() { - future.cancel(false); - } - }; + return () -> future.cancel(false); } /** @@ -190,8 +189,9 @@ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { this.currentCertTime = newResult.certTime; } } catch (IOException | GeneralSecurityException e) { - log.log(Level.SEVERE, "Failed refreshing private key and certificate chain from files. " - + "Using previous ones", e); + log.log(Level.SEVERE, e, () -> String.format("Failed refreshing private key and certificate" + + " chain from files. Using previous ones (keyFile lastModified = %s, certFile " + + "lastModified = %s)", keyFile.lastModified(), certFile.lastModified())); } } } @@ -220,8 +220,8 @@ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { */ private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, long oldCertTime) throws IOException, GeneralSecurityException { - long newKeyTime = keyFile.lastModified(); - long newCertTime = certFile.lastModified(); + long newKeyTime = checkNotNull(keyFile, "keyFile").lastModified(); + long newCertTime = checkNotNull(certFile, "certFile").lastModified(); // We only update when both the key and the certs are updated. if (newKeyTime != oldKeyTime && newCertTime != oldCertTime) { FileInputStream keyInputStream = new FileInputStream(keyFile); diff --git a/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java b/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java index b173b3f5e..9c9998571 100644 --- a/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java +++ b/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java @@ -49,6 +49,11 @@ public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { } @Override + public void inboundHeaders(Metadata headers) { + delegate().inboundHeaders(headers); + } + + @Override public void inboundTrailers(Metadata trailers) { delegate().inboundTrailers(trailers); } diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java new file mode 100644 index 000000000..02813e44f --- /dev/null +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java @@ -0,0 +1,165 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.TestUtils; +import io.grpc.testing.TlsTesting; +import java.io.File; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AdvancedTlsX509KeyManager}. */ +@RunWith(JUnit4.class) +public class AdvancedTlsX509KeyManagerTest { + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String CLIENT_0_KEY_FILE = "client.key"; + private static final String CLIENT_0_PEM_FILE = "client.pem"; + private static final String ALIAS = "default"; + + private ScheduledExecutorService executor; + + private File serverKey0File; + private File serverCert0File; + private File clientKey0File; + private File clientCert0File; + + private PrivateKey serverKey0; + private X509Certificate[] serverCert0; + private PrivateKey clientKey0; + private X509Certificate[] clientCert0; + + @Before + public void setUp() throws Exception { + executor = new FakeClock().getScheduledExecutorService(); + serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE); + clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE); + serverKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(SERVER_0_KEY_FILE)); + serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + clientKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(CLIENT_0_KEY_FILE)); + clientCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CLIENT_0_PEM_FILE)); + } + + @Test + public void credentialSetting() throws Exception { + // Overall happy path checking of public API. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + + serverKeyManager.updateIdentityCredentialsFromFile(clientKey0File, clientCert0File); + assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS)); + assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS)); + + serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, serverCert0File, 1, + TimeUnit.MINUTES, executor); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + } + + @Test + public void credentialSettingParameterValidity() throws Exception { + // Checking edge cases of public API parameter setting. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + NullPointerException npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(null, serverCert0)); + assertEquals("key", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(serverKey0, null)); + assertEquals("certs", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentialsFromFile(null, serverCert0File)); + assertEquals("keyFile", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentialsFromFile(serverKey0File, null)); + assertEquals("certFile", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentialsFromFile(serverKey0File, serverCert0File, 1, null, + executor)); + assertEquals("unit", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentialsFromFile(serverKey0File, serverCert0File, 1, + TimeUnit.MINUTES, null)); + assertEquals("executor", npe.getMessage()); + + Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + TestHandler handler = new TestHandler(); + log.addHandler(handler); + log.setUseParentHandlers(false); + log.setLevel(Level.FINE); + serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, serverCert0File, -1, + TimeUnit.SECONDS, executor); + log.removeHandler(handler); + for (LogRecord record : handler.getRecords()) { + if (record.getMessage().contains("Default value of ")) { + assertTrue(true); + return; + } + } + fail("Log message related to setting default values not found"); + } + + + private static class TestHandler extends Handler { + private final List<LogRecord> records = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + records.add(record); + } + + @Override + public void flush() { + } + + @Override + public void close() throws SecurityException { + } + + public List<LogRecord> getRecords() { + return records; + } + } + +} diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index fe73e1886..16ede8ae1 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -176,13 +176,15 @@ final class CdsLoadBalancer2 extends LoadBalancer { clusterState.result.lrsServerInfo(), clusterState.result.maxConcurrentRequests(), clusterState.result.upstreamTlsContext(), + clusterState.result.filterMetadata(), clusterState.result.outlierDetection()); } else { // logical DNS instance = DiscoveryMechanism.forLogicalDns( clusterState.name, clusterState.result.dnsHostName(), clusterState.result.lrsServerInfo(), clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext()); + clusterState.result.upstreamTlsContext(), + clusterState.result.filterMetadata()); } instances.add(instance); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 45062f28f..e42619d9b 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -21,6 +21,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; @@ -54,6 +56,7 @@ import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; @@ -140,6 +143,7 @@ final class ClusterImplLoadBalancer extends LoadBalancer { childLbHelper.updateDropPolicies(config.dropCategories); childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests); childLbHelper.updateSslContextProviderSupplier(config.tlsContext); + childLbHelper.updateFilterMetadata(config.filterMetadata); childSwitchLb.switchTo(config.childPolicy.getProvider()); childSwitchLb.handleResolvedAddresses( @@ -189,6 +193,7 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private long maxConcurrentRequests = DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; @Nullable private SslContextProviderSupplier sslContextProviderSupplier; + private Map<String, Struct> filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; @@ -201,8 +206,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer { public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { currentState = newState; currentPicker = newPicker; - SubchannelPicker picker = - new RequestLimitingSubchannelPicker(newPicker, dropPolicies, maxConcurrentRequests); + SubchannelPicker picker = new RequestLimitingSubchannelPicker( + newPicker, dropPolicies, maxConcurrentRequests, filterMetadata); delegate().updateBalancingState(newState, picker); } @@ -311,20 +316,29 @@ final class ClusterImplLoadBalancer extends LoadBalancer { : null; } + private void updateFilterMetadata(Map<String, Struct> filterMetadata) { + this.filterMetadata = ImmutableMap.copyOf(filterMetadata); + } + private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final SubchannelPicker delegate; private final List<DropOverload> dropPolicies; private final long maxConcurrentRequests; + private final Map<String, Struct> filterMetadata; private RequestLimitingSubchannelPicker(SubchannelPicker delegate, - List<DropOverload> dropPolicies, long maxConcurrentRequests) { + List<DropOverload> dropPolicies, long maxConcurrentRequests, + Map<String, Struct> filterMetadata) { this.delegate = delegate; this.dropPolicies = dropPolicies; this.maxConcurrentRequests = maxConcurrentRequests; + this.filterMetadata = checkNotNull(filterMetadata, "filterMetadata"); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getCallOptions().getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER) + .accept(filterMetadata); for (DropOverload dropOverload : dropPolicies) { int rand = random.nextInt(1_000_000); if (rand < dropOverload.dropsPerMillion()) { diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java index ff32779b0..b928f6dae 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java @@ -19,6 +19,9 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import io.grpc.CallOptions; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -34,6 +37,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import javax.annotation.Nullable; /** @@ -43,6 +47,11 @@ import javax.annotation.Nullable; */ @Internal public final class ClusterImplLoadBalancerProvider extends LoadBalancerProvider { + /** + * Consumer of filter metadata from the cluster used by the call. Consumer may not modify map. + */ + public static final CallOptions.Key<Consumer<Map<String, Struct>>> FILTER_METADATA_CONSUMER = + CallOptions.Key.createWithDefault("io.grpc.xds.internalFilterMetadataConsumer", (m) -> { }); @Override public boolean isAvailable() { @@ -89,16 +98,18 @@ public final class ClusterImplLoadBalancerProvider extends LoadBalancerProvider final List<DropOverload> dropCategories; // Provides the direct child policy and its config. final PolicySelection childPolicy; + final Map<String, Struct> filterMetadata; ClusterImplConfig(String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, List<DropOverload> dropCategories, PolicySelection childPolicy, - @Nullable UpstreamTlsContext tlsContext) { + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata) { this.cluster = checkNotNull(cluster, "cluster"); this.edsServiceName = edsServiceName; this.lrsServerInfo = lrsServerInfo; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; + this.filterMetadata = ImmutableMap.copyOf(filterMetadata); this.dropCategories = Collections.unmodifiableList( new ArrayList<>(checkNotNull(dropCategories, "dropCategories"))); this.childPolicy = checkNotNull(childPolicy, "childPolicy"); diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index 881628784..f1fb6c0fb 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -21,6 +21,8 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; @@ -184,10 +186,11 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { if (instance.type == DiscoveryMechanism.Type.EDS) { state = new EdsClusterState(instance.cluster, instance.edsServiceName, instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext, - instance.outlierDetection); + instance.filterMetadata, instance.outlierDetection); } else { // logical DNS state = new LogicalDnsClusterState(instance.cluster, instance.dnsHostName, - instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext); + instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext, + instance.filterMetadata); } clusterStates.put(instance.cluster, state); state.start(); @@ -323,6 +326,7 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { protected final Long maxConcurrentRequests; @Nullable protected final UpstreamTlsContext tlsContext; + protected final Map<String, Struct> filterMetadata; @Nullable protected final OutlierDetection outlierDetection; // Resolution status, may contain most recent error encountered. @@ -337,11 +341,12 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { private ClusterState(String name, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - @Nullable OutlierDetection outlierDetection) { + Map<String, Struct> filterMetadata, @Nullable OutlierDetection outlierDetection) { this.name = name; this.lrsServerInfo = lrsServerInfo; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; + this.filterMetadata = ImmutableMap.copyOf(filterMetadata); this.outlierDetection = outlierDetection; } @@ -360,8 +365,10 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { private EdsClusterState(String name, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, outlierDetection); + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata, + @Nullable OutlierDetection outlierDetection) { + super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, + outlierDetection); this.edsServiceName = edsServiceName; } @@ -447,8 +454,8 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { Map<String, PriorityChildConfig> priorityChildConfigs = generateEdsBasedPriorityChildConfigs( name, edsServiceName, lrsServerInfo, maxConcurrentRequests, tlsContext, - outlierDetection, endpointLbPolicy, lbRegistry, prioritizedLocalityWeights, - dropOverloads); + filterMetadata, outlierDetection, endpointLbPolicy, lbRegistry, + prioritizedLocalityWeights, dropOverloads); status = Status.OK; resolved = true; result = new ClusterResolutionResult(addresses, priorityChildConfigs, @@ -533,8 +540,8 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { private LogicalDnsClusterState(String name, String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, null); + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata) { + super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, null); this.dnsHostName = checkNotNull(dnsHostName, "dnsHostName"); nameResolverFactory = checkNotNull(helper.getNameResolverRegistry().asFactory(), "nameResolverFactory"); @@ -623,8 +630,8 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { addresses.add(eag); } PriorityChildConfig priorityChildConfig = generateDnsBasedPriorityChildConfig( - name, lrsServerInfo, maxConcurrentRequests, tlsContext, lbRegistry, - Collections.<DropOverload>emptyList()); + name, lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, + lbRegistry, Collections.<DropOverload>emptyList()); status = Status.OK; resolved = true; result = new ClusterResolutionResult(addresses, priorityName, priorityChildConfig); @@ -707,14 +714,14 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { */ private static PriorityChildConfig generateDnsBasedPriorityChildConfig( String cluster, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, LoadBalancerRegistry lbRegistry, - List<DropOverload> dropOverloads) { + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata, + LoadBalancerRegistry lbRegistry, List<DropOverload> dropOverloads) { // Override endpoint-level LB policy with pick_first for logical DNS cluster. PolicySelection endpointLbPolicy = new PolicySelection(lbRegistry.getProvider("pick_first"), null); ClusterImplConfig clusterImplConfig = new ClusterImplConfig(cluster, null, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbPolicy, tlsContext); + dropOverloads, endpointLbPolicy, tlsContext, filterMetadata); LoadBalancerProvider clusterImplLbProvider = lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); PolicySelection clusterImplPolicy = @@ -731,6 +738,7 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { private static Map<String, PriorityChildConfig> generateEdsBasedPriorityChildConfigs( String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, + Map<String, Struct> filterMetadata, @Nullable OutlierDetection outlierDetection, PolicySelection endpointLbPolicy, LoadBalancerRegistry lbRegistry, Map<String, Map<Locality, Integer>> prioritizedLocalityWeights, List<DropOverload> dropOverloads) { @@ -738,7 +746,7 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { for (String priority : prioritizedLocalityWeights.keySet()) { ClusterImplConfig clusterImplConfig = new ClusterImplConfig(cluster, edsServiceName, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbPolicy, tlsContext); + dropOverloads, endpointLbPolicy, tlsContext, filterMetadata); LoadBalancerProvider clusterImplLbProvider = lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); PolicySelection priorityChildPolicy = diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java index 6488a719a..48ac4155b 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java @@ -19,6 +19,8 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -129,6 +131,7 @@ public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvi final String dnsHostName; @Nullable final OutlierDetection outlierDetection; + final Map<String, Struct> filterMetadata; enum Type { EDS, @@ -138,7 +141,7 @@ public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvi private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - @Nullable OutlierDetection outlierDetection) { + Map<String, Struct> filterMetadata, @Nullable OutlierDetection outlierDetection) { this.cluster = checkNotNull(cluster, "cluster"); this.type = checkNotNull(type, "type"); this.edsServiceName = edsServiceName; @@ -146,28 +149,29 @@ public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvi this.lrsServerInfo = lrsServerInfo; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; + this.filterMetadata = ImmutableMap.copyOf(checkNotNull(filterMetadata, "filterMetadata")); this.outlierDetection = outlierDetection; } static DiscoveryMechanism forEds(String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata, OutlierDetection outlierDetection) { return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, null, lrsServerInfo, - maxConcurrentRequests, tlsContext, outlierDetection); + maxConcurrentRequests, tlsContext, filterMetadata, outlierDetection); } static DiscoveryMechanism forLogicalDns(String cluster, String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { + @Nullable UpstreamTlsContext tlsContext, Map<String, Struct> filterMetadata) { return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, dnsHostName, - lrsServerInfo, maxConcurrentRequests, tlsContext, null); + lrsServerInfo, maxConcurrentRequests, tlsContext, filterMetadata, null); } @Override public int hashCode() { return Objects.hash(cluster, type, lrsServerInfo, maxConcurrentRequests, tlsContext, - edsServiceName, dnsHostName); + edsServiceName, dnsHostName, filterMetadata, outlierDetection); } @Override @@ -185,7 +189,9 @@ public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvi && Objects.equals(dnsHostName, that.dnsHostName) && Objects.equals(lrsServerInfo, that.lrsServerInfo) && Objects.equals(maxConcurrentRequests, that.maxConcurrentRequests) - && Objects.equals(tlsContext, that.tlsContext); + && Objects.equals(tlsContext, that.tlsContext) + && Objects.equals(filterMetadata, that.filterMetadata) + && Objects.equals(outlierDetection, that.outlierDetection); } @Override @@ -198,7 +204,10 @@ public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvi .add("dnsHostName", dnsHostName) .add("lrsServerInfo", lrsServerInfo) // Exclude tlsContext as its string representation is cumbersome. - .add("maxConcurrentRequests", maxConcurrentRequests); + .add("maxConcurrentRequests", maxConcurrentRequests) + .add("filterMetadata", filterMetadata) + // Exclude outlierDetection as its string representation is long. + ; return toStringHelper.toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java new file mode 100644 index 000000000..929619c11 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Internal; +import io.grpc.xds.client.XdsInitializationException; +import java.io.IOException; + +/** + * Internal accessors for GrpcBootstrapperImpl. + */ +@Internal +public final class InternalGrpcBootstrapperImpl { + private InternalGrpcBootstrapperImpl() {} // prevent instantiation + + public static String getJsonContent() throws XdsInitializationException, IOException { + return new GrpcBootstrapperImpl().getJsonContent(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java index 6b6c48972..c6340156d 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableMap; import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import com.google.protobuf.Struct; import com.google.protobuf.util.Durations; import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers.Thresholds; import io.envoyproxy.envoy.config.cluster.v3.Cluster; @@ -160,6 +161,8 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> { } updateBuilder.lbPolicyConfig(lbPolicyConfig); + updateBuilder.filterMetadata( + ImmutableMap.copyOf(cluster.getMetadata().getFilterMetadataMap())); return updateBuilder.build(); } @@ -559,14 +562,21 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> { @Nullable abstract OutlierDetection outlierDetection(); - static Builder forAggregate(String clusterName, List<String> prioritizedClusterNames) { - checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); + abstract ImmutableMap<String, Struct> filterMetadata(); + + private static Builder newBuilder(String clusterName) { return new AutoValue_XdsClusterResource_CdsUpdate.Builder() .clusterName(clusterName) - .clusterType(ClusterType.AGGREGATE) .minRingSize(0) .maxRingSize(0) .choiceCount(0) + .filterMetadata(ImmutableMap.of()); + } + + static Builder forAggregate(String clusterName, List<String> prioritizedClusterNames) { + checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); + return newBuilder(clusterName) + .clusterType(ClusterType.AGGREGATE) .prioritizedClusterNames(ImmutableList.copyOf(prioritizedClusterNames)); } @@ -574,12 +584,8 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> { @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext upstreamTlsContext, @Nullable OutlierDetection outlierDetection) { - return new AutoValue_XdsClusterResource_CdsUpdate.Builder() - .clusterName(clusterName) + return newBuilder(clusterName) .clusterType(ClusterType.EDS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) .edsServiceName(edsServiceName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) @@ -591,12 +597,8 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> { @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext upstreamTlsContext) { - return new AutoValue_XdsClusterResource_CdsUpdate.Builder() - .clusterName(clusterName) + return newBuilder(clusterName) .clusterType(ClusterType.LOGICAL_DNS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) .dnsHostName(dnsHostName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) @@ -685,6 +687,8 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> { protected abstract Builder outlierDetection(OutlierDetection outlierDetection); + protected abstract Builder filterMetadata(ImmutableMap<String, Struct> filterMetadata); + abstract CdsUpdate build(); } } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index fa08d5edd..cf140a076 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -177,7 +177,8 @@ public class ClusterImplLoadBalancerTest { Object weightedTargetConfig = new Object(); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -202,7 +203,8 @@ public class ClusterImplLoadBalancerTest { ClusterImplConfig configWithWeightedTarget = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWeightedTarget); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -215,7 +217,8 @@ public class ClusterImplLoadBalancerTest { ClusterImplConfig configWithWrrLocality = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(wrrLocalityProvider, wrrLocalityConfig), null); + new PolicySelection(wrrLocalityProvider, wrrLocalityConfig), null, + Collections.emptyMap()); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWrrLocality); childBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(childBalancer.name).isEqualTo(XdsLbPolicies.WRR_LOCALITY_POLICY_NAME); @@ -239,7 +242,8 @@ public class ClusterImplLoadBalancerTest { Object weightedTargetConfig = new Object(); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -258,7 +262,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -284,7 +289,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -368,7 +374,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.singletonList(DropOverload.create("throttle", 500_000)), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); when(mockRandom.nextInt(anyInt())).thenReturn(499_999, 999_999, 1_000_000); @@ -397,7 +404,8 @@ public class ClusterImplLoadBalancerTest { // Config update updates drop policies. config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.singletonList(DropOverload.create("lb", 1_000_000)), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.singletonList(endpoint)) @@ -444,7 +452,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, maxConcurrentRequests, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -486,7 +495,8 @@ public class ClusterImplLoadBalancerTest { maxConcurrentRequests = 101L; config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, maxConcurrentRequests, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); result = currentPicker.pickSubchannel(pickSubchannelArgs); @@ -532,7 +542,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -578,7 +589,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -615,7 +627,8 @@ public class ClusterImplLoadBalancerTest { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext, + Collections.emptyMap()); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -638,7 +651,8 @@ public class ClusterImplLoadBalancerTest { // Removes UpstreamTlsContext from the config. config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), null, + Collections.emptyMap()); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections @@ -652,7 +666,8 @@ public class ClusterImplLoadBalancerTest { CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true); config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.<DropOverload>emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext); + new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext, + Collections.emptyMap()); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 99b8605b4..dd503592c 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -136,15 +136,16 @@ public class ClusterResolverLoadBalancerTest { FailurePercentageEjection.create(100, 100, 100, 100)); private final DiscoveryMechanism edsDiscoveryMechanism1 = DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - null); + Collections.emptyMap(), null); private final DiscoveryMechanism edsDiscoveryMechanism2 = DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, tlsContext, - null); + Collections.emptyMap(), null); private final DiscoveryMechanism edsDiscoveryMechanismWithOutlierDetection = DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - outlierDetection); + Collections.emptyMap(), outlierDetection); private final DiscoveryMechanism logicalDnsDiscoveryMechanism = - DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_INFO, 300L, null); + DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_INFO, 300L, null, + Collections.emptyMap()); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index 0ac024d1e..bf330b100 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -169,6 +169,9 @@ public class CsdsServiceTest { grpcServerRule.getServiceRegistry() .addService(new CsdsService(new FakeXdsClientPoolFactory(throwingXdsClient))); + // Hack to prevent the interrupted exception from propagating through to the client stub. + grpcServerRule.getChannel().getState(true); + try { ClientStatusResponse response = csdsStub.fetchClientStatus(REQUEST); fail("Should've failed, got response: " + response); |