aboutsummaryrefslogtreecommitdiff
path: root/services/src
diff options
context:
space:
mode:
Diffstat (limited to 'services/src')
-rw-r--r--services/src/main/java/io/grpc/services/BinaryLogProvider.java194
-rw-r--r--services/src/main/java/io/grpc/services/BinaryLogProviderImpl.java1
-rw-r--r--services/src/main/java/io/grpc/services/BinlogHelper.java2
-rw-r--r--services/src/main/java/io/grpc/services/CensusBinaryLogProvider.java1
-rw-r--r--services/src/test/java/io/grpc/services/BinaryLogProviderTest.java439
-rw-r--r--services/src/test/java/io/grpc/services/BinlogHelperTest.java3
-rw-r--r--services/src/test/java/io/grpc/services/CensusBinaryLogProviderTest.java1
7 files changed, 635 insertions, 6 deletions
diff --git a/services/src/main/java/io/grpc/services/BinaryLogProvider.java b/services/src/main/java/io/grpc/services/BinaryLogProvider.java
new file mode 100644
index 000000000..3ca5de655
--- /dev/null
+++ b/services/src/main/java/io/grpc/services/BinaryLogProvider.java
@@ -0,0 +1,194 @@
+/*
+ * Copyright 2017 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.services;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.grpc.BinaryLog;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ClientInterceptors;
+import io.grpc.Internal;
+import io.grpc.InternalClientInterceptors;
+import io.grpc.InternalServerInterceptors;
+import io.grpc.ManagedChannel;
+import io.grpc.MethodDescriptor;
+import io.grpc.MethodDescriptor.Marshaller;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.ServerMethodDefinition;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import javax.annotation.Nullable;
+
+// TODO(zpencer): rename class to AbstractBinaryLog
+@Internal
+public abstract class BinaryLogProvider extends BinaryLog {
+ @VisibleForTesting
+ public static final Marshaller<byte[]> BYTEARRAY_MARSHALLER = new ByteArrayMarshaller();
+
+ private final ClientInterceptor binaryLogShim = new BinaryLogShim();
+
+ /**
+ * Wraps a channel to provide binary logging on {@link ClientCall}s as needed.
+ */
+ @Override
+ public final Channel wrapChannel(Channel channel) {
+ return ClientInterceptors.intercept(channel, binaryLogShim);
+ }
+
+ private static MethodDescriptor<byte[], byte[]> toByteBufferMethod(
+ MethodDescriptor<?, ?> method) {
+ return method.toBuilder(BYTEARRAY_MARSHALLER, BYTEARRAY_MARSHALLER).build();
+ }
+
+ /**
+ * Wraps a {@link ServerMethodDefinition} such that it performs binary logging if needed.
+ */
+ @Override
+ public final <ReqT, RespT> ServerMethodDefinition<?, ?> wrapMethodDefinition(
+ ServerMethodDefinition<ReqT, RespT> oMethodDef) {
+ ServerInterceptor binlogInterceptor =
+ getServerInterceptor(oMethodDef.getMethodDescriptor().getFullMethodName());
+ if (binlogInterceptor == null) {
+ return oMethodDef;
+ }
+ MethodDescriptor<byte[], byte[]> binMethod =
+ BinaryLogProvider.toByteBufferMethod(oMethodDef.getMethodDescriptor());
+ ServerMethodDefinition<byte[], byte[]> binDef =
+ InternalServerInterceptors.wrapMethod(oMethodDef, binMethod);
+ ServerCallHandler<byte[], byte[]> binlogHandler =
+ InternalServerInterceptors.interceptCallHandlerCreate(
+ binlogInterceptor, binDef.getServerCallHandler());
+ return ServerMethodDefinition.create(binMethod, binlogHandler);
+ }
+
+ /**
+ * Returns a {@link ServerInterceptor} for binary logging. gRPC is free to cache the interceptor,
+ * so the interceptor must be reusable across calls. At runtime, the request and response
+ * marshallers are always {@code Marshaller<InputStream>}.
+ * Returns {@code null} if this method is not binary logged.
+ */
+ // TODO(zpencer): ensure the interceptor properly handles retries and hedging
+ @Nullable
+ protected abstract ServerInterceptor getServerInterceptor(String fullMethodName);
+
+ /**
+ * Returns a {@link ClientInterceptor} for binary logging. gRPC is free to cache the interceptor,
+ * so the interceptor must be reusable across calls. At runtime, the request and response
+ * marshallers are always {@code Marshaller<InputStream>}.
+ * Returns {@code null} if this method is not binary logged.
+ */
+ // TODO(zpencer): ensure the interceptor properly handles retries and hedging
+ @Nullable
+ protected abstract ClientInterceptor getClientInterceptor(
+ String fullMethodName, CallOptions callOptions);
+
+ @Override
+ public void close() throws IOException {
+ // default impl: noop
+ // TODO(zpencer): make BinaryLogProvider provide a BinaryLog, and this method belongs there
+ }
+
+ // Creating a named class makes debugging easier
+ private static final class ByteArrayMarshaller implements Marshaller<byte[]> {
+ @Override
+ public InputStream stream(byte[] value) {
+ return new ByteArrayInputStream(value);
+ }
+
+ @Override
+ public byte[] parse(InputStream stream) {
+ try {
+ return parseHelper(stream);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private byte[] parseHelper(InputStream stream) throws IOException {
+ try {
+ return IoUtils.toByteArray(stream);
+ } finally {
+ stream.close();
+ }
+ }
+ }
+
+ /**
+ * The pipeline of interceptors is hard coded when the {@link ManagedChannel} is created.
+ * This shim interceptor should always be installed as a placeholder. When a call starts,
+ * this interceptor checks with the {@link BinaryLogProvider} to see if logging should happen
+ * for this particular {@link ClientCall}'s method.
+ */
+ private final class BinaryLogShim implements ClientInterceptor {
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+ MethodDescriptor<ReqT, RespT> method,
+ CallOptions callOptions,
+ Channel next) {
+ ClientInterceptor binlogInterceptor = getClientInterceptor(
+ method.getFullMethodName(), callOptions);
+ if (binlogInterceptor == null) {
+ return next.newCall(method, callOptions);
+ } else {
+ return InternalClientInterceptors
+ .wrapClientInterceptor(
+ binlogInterceptor,
+ BYTEARRAY_MARSHALLER,
+ BYTEARRAY_MARSHALLER)
+ .interceptCall(method, callOptions, next);
+ }
+ }
+ }
+
+ // Copied from internal
+ private static final class IoUtils {
+ /** maximum buffer to be read is 16 KB. */
+ private static final int MAX_BUFFER_LENGTH = 16384;
+
+ /** Returns the byte array. */
+ public static byte[] toByteArray(InputStream in) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ copy(in, out);
+ return out.toByteArray();
+ }
+
+ /** Copies the data from input stream to output stream. */
+ public static long copy(InputStream from, OutputStream to) throws IOException {
+ // Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta)
+ Preconditions.checkNotNull(from);
+ Preconditions.checkNotNull(to);
+ byte[] buf = new byte[MAX_BUFFER_LENGTH];
+ long total = 0;
+ while (true) {
+ int r = from.read(buf);
+ if (r == -1) {
+ break;
+ }
+ to.write(buf, 0, r);
+ total += r;
+ }
+ return total;
+ }
+ }
+}
diff --git a/services/src/main/java/io/grpc/services/BinaryLogProviderImpl.java b/services/src/main/java/io/grpc/services/BinaryLogProviderImpl.java
index 826e2af8d..55109bc92 100644
--- a/services/src/main/java/io/grpc/services/BinaryLogProviderImpl.java
+++ b/services/src/main/java/io/grpc/services/BinaryLogProviderImpl.java
@@ -17,7 +17,6 @@
package io.grpc.services;
import com.google.common.base.Preconditions;
-import io.grpc.BinaryLogProvider;
import io.grpc.CallOptions;
import io.grpc.ClientInterceptor;
import io.grpc.ServerInterceptor;
diff --git a/services/src/main/java/io/grpc/services/BinlogHelper.java b/services/src/main/java/io/grpc/services/BinlogHelper.java
index 9d06cd5f9..1d8d1443f 100644
--- a/services/src/main/java/io/grpc/services/BinlogHelper.java
+++ b/services/src/main/java/io/grpc/services/BinlogHelper.java
@@ -16,7 +16,7 @@
package io.grpc.services;
-import static io.grpc.BinaryLogProvider.BYTEARRAY_MARSHALLER;
+import static io.grpc.services.BinaryLogProvider.BYTEARRAY_MARSHALLER;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
diff --git a/services/src/main/java/io/grpc/services/CensusBinaryLogProvider.java b/services/src/main/java/io/grpc/services/CensusBinaryLogProvider.java
index f255d67e8..db6a0b62d 100644
--- a/services/src/main/java/io/grpc/services/CensusBinaryLogProvider.java
+++ b/services/src/main/java/io/grpc/services/CensusBinaryLogProvider.java
@@ -16,7 +16,6 @@
package io.grpc.services;
-import io.grpc.BinaryLogProvider;
import io.grpc.CallOptions;
import io.opencensus.trace.Span;
import io.opencensus.trace.Tracing;
diff --git a/services/src/test/java/io/grpc/services/BinaryLogProviderTest.java b/services/src/test/java/io/grpc/services/BinaryLogProviderTest.java
new file mode 100644
index 000000000..30530f8b6
--- /dev/null
+++ b/services/src/test/java/io/grpc/services/BinaryLogProviderTest.java
@@ -0,0 +1,439 @@
+/*
+ * Copyright 2017 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.services;
+
+import static com.google.common.base.Charsets.UTF_8;
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.io.ByteStreams;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
+import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
+import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
+import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.MethodDescriptor.Marshaller;
+import io.grpc.MethodDescriptor.MethodType;
+import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.ServerMethodDefinition;
+import io.grpc.internal.IoUtils;
+import io.grpc.internal.NoopClientCall;
+import io.grpc.internal.NoopServerCall;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link BinaryLogProvider}. */
+@RunWith(JUnit4.class)
+public class BinaryLogProviderTest {
+ private final InvocationCountMarshaller<String> reqMarshaller =
+ new InvocationCountMarshaller<String>() {
+ @Override
+ Marshaller<String> delegate() {
+ return StringMarshaller.INSTANCE;
+ }
+ };
+ private final InvocationCountMarshaller<Integer> respMarshaller =
+ new InvocationCountMarshaller<Integer>() {
+ @Override
+ Marshaller<Integer> delegate() {
+ return IntegerMarshaller.INSTANCE;
+ }
+ };
+ private final MethodDescriptor<String, Integer> method =
+ MethodDescriptor
+ .newBuilder(reqMarshaller, respMarshaller)
+ .setFullMethodName("myservice/mymethod")
+ .setType(MethodType.UNARY)
+ .setSchemaDescriptor(new Object())
+ .setIdempotent(true)
+ .setSafe(true)
+ .setSampledToLocalTracing(true)
+ .build();
+ private final List<byte[]> binlogReq = new ArrayList<byte[]>();
+ private final List<byte[]> binlogResp = new ArrayList<byte[]>();
+ private final BinaryLogProvider binlogProvider = new BinaryLogProvider() {
+ @Override
+ public ServerInterceptor getServerInterceptor(String fullMethodName) {
+ return new TestBinaryLogServerInterceptor();
+ }
+
+ @Override
+ public ClientInterceptor getClientInterceptor(
+ String fullMethodName, CallOptions callOptions) {
+ return new TestBinaryLogClientInterceptor();
+ }
+ };
+
+ @Test
+ public void wrapChannel_methodDescriptor() throws Exception {
+ final AtomicReference<MethodDescriptor<?, ?>> methodRef =
+ new AtomicReference<MethodDescriptor<?, ?>>();
+ Channel channel = new Channel() {
+ @Override
+ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
+ MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions) {
+ methodRef.set(method);
+ return new NoopClientCall<RequestT, ResponseT>();
+ }
+
+ @Override
+ public String authority() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ Channel wChannel = binlogProvider.wrapChannel(channel);
+ ClientCall<String, Integer> unusedClientCall = wChannel.newCall(method, CallOptions.DEFAULT);
+ validateWrappedMethod(methodRef.get());
+ }
+
+ @Test
+ public void wrapChannel_handler() throws Exception {
+ final List<byte[]> serializedReq = new ArrayList<byte[]>();
+ final AtomicReference<ClientCall.Listener<?>> listener =
+ new AtomicReference<ClientCall.Listener<?>>();
+ Channel channel = new Channel() {
+ @Override
+ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
+ MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
+ return new NoopClientCall<RequestT, ResponseT>() {
+ @Override
+ public void start(Listener<ResponseT> responseListener, Metadata headers) {
+ listener.set(responseListener);
+ }
+
+ @Override
+ public void sendMessage(RequestT message) {
+ serializedReq.add((byte[]) message);
+ }
+ };
+ }
+
+ @Override
+ public String authority() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ Channel wChannel = binlogProvider.wrapChannel(channel);
+ ClientCall<String, Integer> clientCall = wChannel.newCall(method, CallOptions.DEFAULT);
+ final List<Integer> observedResponse = new ArrayList<Integer>();
+ clientCall.start(
+ new NoopClientCall.NoopClientCallListener<Integer>() {
+ @Override
+ public void onMessage(Integer message) {
+ observedResponse.add(message);
+ }
+ },
+ new Metadata());
+
+ String expectedRequest = "hello world";
+ assertThat(binlogReq).isEmpty();
+ assertThat(serializedReq).isEmpty();
+ assertEquals(0, reqMarshaller.streamInvocations);
+ clientCall.sendMessage(expectedRequest);
+ // it is unacceptably expensive for the binlog to double parse every logged message
+ assertEquals(1, reqMarshaller.streamInvocations);
+ assertEquals(0, reqMarshaller.parseInvocations);
+ assertThat(binlogReq).hasSize(1);
+ assertThat(serializedReq).hasSize(1);
+ assertEquals(
+ expectedRequest,
+ StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0))));
+ assertEquals(
+ expectedRequest,
+ StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(serializedReq.get(0))));
+
+ int expectedResponse = 12345;
+ assertThat(binlogResp).isEmpty();
+ assertThat(observedResponse).isEmpty();
+ assertEquals(0, respMarshaller.parseInvocations);
+ onClientMessageHelper(listener.get(), IntegerMarshaller.INSTANCE.stream(expectedResponse));
+ // it is unacceptably expensive for the binlog to double parse every logged message
+ assertEquals(1, respMarshaller.parseInvocations);
+ assertEquals(0, respMarshaller.streamInvocations);
+ assertThat(binlogResp).hasSize(1);
+ assertThat(observedResponse).hasSize(1);
+ assertEquals(
+ expectedResponse,
+ (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0))));
+ assertEquals(expectedResponse, (int) observedResponse.get(0));
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ private static void onClientMessageHelper(ClientCall.Listener listener, Object request) {
+ listener.onMessage(request);
+ }
+
+ private void validateWrappedMethod(MethodDescriptor<?, ?> wMethod) {
+ assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getRequestMarshaller());
+ assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getResponseMarshaller());
+ assertEquals(method.getType(), wMethod.getType());
+ assertEquals(method.getFullMethodName(), wMethod.getFullMethodName());
+ assertEquals(method.getSchemaDescriptor(), wMethod.getSchemaDescriptor());
+ assertEquals(method.isIdempotent(), wMethod.isIdempotent());
+ assertEquals(method.isSafe(), wMethod.isSafe());
+ assertEquals(method.isSampledToLocalTracing(), wMethod.isSampledToLocalTracing());
+ }
+
+ @Test
+ public void wrapMethodDefinition_methodDescriptor() throws Exception {
+ ServerMethodDefinition<String, Integer> methodDef =
+ ServerMethodDefinition.create(
+ method,
+ new ServerCallHandler<String, Integer>() {
+ @Override
+ public Listener<String> startCall(
+ ServerCall<String, Integer> call, Metadata headers) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ ServerMethodDefinition<?, ?> wMethodDef = binlogProvider.wrapMethodDefinition(methodDef);
+ validateWrappedMethod(wMethodDef.getMethodDescriptor());
+ }
+
+ @Test
+ public void wrapMethodDefinition_handler() throws Exception {
+ // The request as seen by the user supplied server code
+ final List<String> observedRequest = new ArrayList<String>();
+ final AtomicReference<ServerCall<String, Integer>> serverCall =
+ new AtomicReference<ServerCall<String, Integer>>();
+ ServerMethodDefinition<String, Integer> methodDef =
+ ServerMethodDefinition.create(
+ method,
+ new ServerCallHandler<String, Integer>() {
+ @Override
+ public ServerCall.Listener<String> startCall(
+ ServerCall<String, Integer> call, Metadata headers) {
+ serverCall.set(call);
+ return new ServerCall.Listener<String>() {
+ @Override
+ public void onMessage(String message) {
+ observedRequest.add(message);
+ }
+ };
+ }
+ });
+ ServerMethodDefinition<?, ?> wDef = binlogProvider.wrapMethodDefinition(methodDef);
+ List<Object> serializedResp = new ArrayList<Object>();
+ ServerCall.Listener<?> wListener = startServerCallHelper(wDef, serializedResp);
+
+ String expectedRequest = "hello world";
+ assertThat(binlogReq).isEmpty();
+ assertThat(observedRequest).isEmpty();
+ assertEquals(0, reqMarshaller.parseInvocations);
+ onServerMessageHelper(wListener, StringMarshaller.INSTANCE.stream(expectedRequest));
+ // it is unacceptably expensive for the binlog to double parse every logged message
+ assertEquals(1, reqMarshaller.parseInvocations);
+ assertEquals(0, reqMarshaller.streamInvocations);
+ assertThat(binlogReq).hasSize(1);
+ assertThat(observedRequest).hasSize(1);
+ assertEquals(
+ expectedRequest,
+ StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0))));
+ assertEquals(expectedRequest, observedRequest.get(0));
+
+ int expectedResponse = 12345;
+ assertThat(binlogResp).isEmpty();
+ assertThat(serializedResp).isEmpty();
+ assertEquals(0, respMarshaller.streamInvocations);
+ serverCall.get().sendMessage(expectedResponse);
+ // it is unacceptably expensive for the binlog to double parse every logged message
+ assertEquals(0, respMarshaller.parseInvocations);
+ assertEquals(1, respMarshaller.streamInvocations);
+ assertThat(binlogResp).hasSize(1);
+ assertThat(serializedResp).hasSize(1);
+ assertEquals(
+ expectedResponse,
+ (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0))));
+ assertEquals(expectedResponse,
+ (int) method.parseResponse(new ByteArrayInputStream((byte[]) serializedResp.get(0))));
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ private static void onServerMessageHelper(ServerCall.Listener listener, Object request) {
+ listener.onMessage(request);
+ }
+
+ private static <ReqT, RespT> ServerCall.Listener<ReqT> startServerCallHelper(
+ final ServerMethodDefinition<ReqT, RespT> methodDef,
+ final List<Object> serializedResp) {
+ ServerCall<ReqT, RespT> serverCall = new NoopServerCall<ReqT, RespT>() {
+ @Override
+ public void sendMessage(RespT message) {
+ serializedResp.add(message);
+ }
+
+ @Override
+ public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
+ return methodDef.getMethodDescriptor();
+ }
+ };
+ return methodDef.getServerCallHandler().startCall(serverCall, new Metadata());
+ }
+
+ private final class TestBinaryLogClientInterceptor implements ClientInterceptor {
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+ final MethodDescriptor<ReqT, RespT> method,
+ CallOptions callOptions,
+ Channel next) {
+ assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getRequestMarshaller());
+ assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getResponseMarshaller());
+ return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ super.start(
+ new SimpleForwardingClientCallListener<RespT>(responseListener) {
+ @Override
+ public void onMessage(RespT message) {
+ assertTrue(message instanceof InputStream);
+ try {
+ byte[] bytes = IoUtils.toByteArray((InputStream) message);
+ binlogResp.add(bytes);
+ ByteArrayInputStream input = new ByteArrayInputStream(bytes);
+ RespT dup = method.parseResponse(input);
+ super.onMessage(dup);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ },
+ headers);
+ }
+
+ @Override
+ public void sendMessage(ReqT message) {
+ byte[] bytes = (byte[]) message;
+ binlogReq.add(bytes);
+ ByteArrayInputStream input = new ByteArrayInputStream(bytes);
+ ReqT dup = method.parseRequest(input);
+ super.sendMessage(dup);
+ }
+ };
+ }
+ }
+
+ private final class TestBinaryLogServerInterceptor implements ServerInterceptor {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ final ServerCall<ReqT, RespT> call,
+ Metadata headers,
+ ServerCallHandler<ReqT, RespT> next) {
+ assertSame(
+ BinaryLogProvider.BYTEARRAY_MARSHALLER,
+ call.getMethodDescriptor().getRequestMarshaller());
+ assertSame(
+ BinaryLogProvider.BYTEARRAY_MARSHALLER,
+ call.getMethodDescriptor().getResponseMarshaller());
+ ServerCall<ReqT, RespT> wCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
+ @Override
+ public void sendMessage(RespT message) {
+ byte[] bytes = (byte[]) message;
+ binlogResp.add(bytes);
+ ByteArrayInputStream input = new ByteArrayInputStream(bytes);
+ RespT dup = call.getMethodDescriptor().parseResponse(input);
+ super.sendMessage(dup);
+ }
+ };
+ final ServerCall.Listener<ReqT> oListener = next.startCall(wCall, headers);
+ return new SimpleForwardingServerCallListener<ReqT>(oListener) {
+ @Override
+ public void onMessage(ReqT message) {
+ assertTrue(message instanceof InputStream);
+ try {
+ byte[] bytes = IoUtils.toByteArray((InputStream) message);
+ binlogReq.add(bytes);
+ ByteArrayInputStream input = new ByteArrayInputStream(bytes);
+ ReqT dup = call.getMethodDescriptor().parseRequest(input);
+ super.onMessage(dup);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+ }
+ }
+
+ private abstract static class InvocationCountMarshaller<T>
+ implements MethodDescriptor.Marshaller<T> {
+ private int streamInvocations = 0;
+ private int parseInvocations = 0;
+
+ abstract MethodDescriptor.Marshaller<T> delegate();
+
+ @Override
+ public InputStream stream(T value) {
+ streamInvocations++;
+ return delegate().stream(value);
+ }
+
+ @Override
+ public T parse(InputStream stream) {
+ parseInvocations++;
+ return delegate().parse(stream);
+ }
+ }
+
+
+ private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
+ public static StringMarshaller INSTANCE = new StringMarshaller();
+
+ @Override
+ public InputStream stream(String value) {
+ return new ByteArrayInputStream(value.getBytes(UTF_8));
+ }
+
+ @Override
+ public String parse(InputStream stream) {
+ try {
+ return new String(ByteStreams.toByteArray(stream), UTF_8);
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+ }
+
+ private static class IntegerMarshaller implements MethodDescriptor.Marshaller<Integer> {
+ public static final IntegerMarshaller INSTANCE = new IntegerMarshaller();
+
+ @Override
+ public InputStream stream(Integer value) {
+ return StringMarshaller.INSTANCE.stream(value.toString());
+ }
+
+ @Override
+ public Integer parse(InputStream stream) {
+ return Integer.valueOf(StringMarshaller.INSTANCE.parse(stream));
+ }
+ }
+}
diff --git a/services/src/test/java/io/grpc/services/BinlogHelperTest.java b/services/src/test/java/io/grpc/services/BinlogHelperTest.java
index a11502a49..c19d0182a 100644
--- a/services/src/test/java/io/grpc/services/BinlogHelperTest.java
+++ b/services/src/test/java/io/grpc/services/BinlogHelperTest.java
@@ -16,7 +16,7 @@
package io.grpc.services;
-import static io.grpc.BinaryLogProvider.BYTEARRAY_MARSHALLER;
+import static io.grpc.services.BinaryLogProvider.BYTEARRAY_MARSHALLER;
import static io.grpc.services.BinlogHelper.DUMMY_SOCKET;
import static io.grpc.services.BinlogHelper.getPeerSocket;
import static org.junit.Assert.assertEquals;
@@ -31,7 +31,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.protobuf.ByteString;
import io.grpc.Attributes;
import io.grpc.BinaryLog.CallId;
-import io.grpc.BinaryLogProvider;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
diff --git a/services/src/test/java/io/grpc/services/CensusBinaryLogProviderTest.java b/services/src/test/java/io/grpc/services/CensusBinaryLogProviderTest.java
index 1c59026bd..8fabbd1f9 100644
--- a/services/src/test/java/io/grpc/services/CensusBinaryLogProviderTest.java
+++ b/services/src/test/java/io/grpc/services/CensusBinaryLogProviderTest.java
@@ -20,7 +20,6 @@ import static com.google.common.truth.Truth.assertThat;
import static io.opencensus.trace.unsafe.ContextUtils.CONTEXT_SPAN_KEY;
import io.grpc.BinaryLog.CallId;
-import io.grpc.BinaryLogProvider;
import io.grpc.CallOptions;
import io.grpc.Context;
import io.grpc.internal.testing.StatsTestUtils.MockableSpan;