aboutsummaryrefslogtreecommitdiff
path: root/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java
blob: 2a88a3f891334030511c02028c835267b7fca3da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
/*
 * Copyright 2014, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.protobuf.lite;

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import io.grpc.ExperimentalApi;
import io.grpc.KnownLength;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.PrototypeMarshaller;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.Reference;
import java.lang.ref.WeakReference;

/**
 * Utility methods for using protobuf with grpc.
 *
 * <p>All members are static; the class cannot be instantiated.
 */
@ExperimentalApi("Experimental until Lite is stable in protobuf")
public final class ProtoLiteUtils {

  /** Starting size, in bytes, of each thread's scratch buffer. The value is arbitrary. */
  private static final int INITIAL_BUFFER_SIZE = 4096;

  /**
   * Extension registry passed to every parse performed by the marshallers created by this class.
   * Declared {@code volatile} so that a registry installed through
   * {@link #setExtensionRegistry} is visible to other threads.
   */
  private static volatile ExtensionRegistryLite globalRegistry =
      ExtensionRegistryLite.getEmptyRegistry();

  /**
   * Sets the global registry for proto marshalling shared across all servers and clients.
   *
   * <p>Warning:  This API will likely change over time.  It is not possible to have separate
   * registries per Process, Server, Channel, Service, or Method.  This is intentional until there
   * is a more appropriate API to set them.
   *
   * <p>Warning:  Do NOT modify the extension registry after setting it.  It is thread safe to call
   * {@link #setExtensionRegistry}, but not to modify the underlying object.
   *
   * <p>If you need custom parsing behavior for protos, you will need to make your own
   * {@code MethodDescriptor.Marshaller} for the time being.
   *
   * @param newRegistry the registry to use for all subsequent parses; must not be {@code null}
   * @throws NullPointerException if {@code newRegistry} is {@code null}
   */
  @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1787")
  public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
    globalRegistry = checkNotNull(newRegistry, "newRegistry");
  }

  /**
   * Per-thread scratch buffer used when parsing streams of known length. The buffer is held
   * through a {@link WeakReference}, so it may be reclaimed by the garbage collector; callers must
   * handle a cleared reference by allocating a fresh buffer.
   */
  private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<Reference<byte[]>>() {

    @Override
    protected Reference<byte[]> initialValue() {
      return new WeakReference<byte[]>(new byte[INITIAL_BUFFER_SIZE]);
    }
  };

  /**
   * Create a {@code Marshaller} for protos of the same type as {@code defaultInstance}.
   *
   * @param defaultInstance prototype message; supplies the parser and is returned as-is when an
   *     empty known-length stream is parsed
   * @return a marshaller that serializes through {@code ProtoInputStream} and parses with the
   *     prototype's parser and the current global extension registry
   */
  public static <T extends MessageLite> Marshaller<T> marshaller(final T defaultInstance) {
    @SuppressWarnings("unchecked")
    final Parser<T> parser = (Parser<T>) defaultInstance.getParserForType();
    // TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
    return new PrototypeMarshaller<T>() {
      @SuppressWarnings("unchecked")
      @Override
      public Class<T> getMessageClass() {
        // Precisely T since protobuf doesn't let messages extend other messages.
        return (Class<T>) defaultInstance.getClass();
      }

      @Override
      public T getMessagePrototype() {
        return defaultInstance;
      }

      @Override
      public InputStream stream(T value) {
        return new ProtoInputStream(value, parser);
      }

      /**
       * Parses a message from {@code stream}.
       *
       * <p>Three paths are tried in order: (1) if the stream is a {@code ProtoInputStream} built
       * with this marshaller's parser, its message is returned without re-parsing; (2) if the
       * stream reports a known length within {@code GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE}, it is
       * read fully into the thread's scratch buffer and parsed from there; (3) otherwise it is
       * parsed directly from the stream.
       *
       * @throws RuntimeException if reading the stream fails, if the reported length is
       *     inaccurate, or (as a {@code Status.INTERNAL} runtime exception) if the bytes are not
       *     a valid protobuf
       */
      @Override
      public T parse(InputStream stream) {
        if (stream instanceof ProtoInputStream) {
          ProtoInputStream protoStream = (ProtoInputStream) stream;
          // Optimization for in-memory transport. Returning provided object is safe since protobufs
          // are immutable.
          //
          // However, we can't assume the types match, so we have to verify the parser matches.
          // Today the parser is always the same for a given proto, but that isn't guaranteed. Even
          // if not, using the same MethodDescriptor would ensure the parser matches and permit us
          // to enable this optimization.
          if (protoStream.parser() == parser) {
            try {
              @SuppressWarnings("unchecked")
              T message = (T) protoStream.message();
              return message;
            } catch (IllegalStateException ignored) {
              // Stream must have been read from, which is a strange state. Since the point of this
              // optimization is to be transparent, instead of throwing an error we'll continue,
              // even though it seems likely there's a bug.
            }
          }
        }
        CodedInputStream cis = null;
        try {
          if (stream instanceof KnownLength) {
            int size = stream.available();
            if (size == 0) {
              // Nothing to parse: an empty payload is the default message.
              return defaultInstance;
            }
            if (size > 0 && size <= GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE) {
              cis = CodedInputStream.newInstance(readFully(stream, size), 0, size);
            }
          }
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
        if (cis == null) {
          // Unknown, negative, or oversized length: parse straight from the stream.
          cis = CodedInputStream.newInstance(stream);
        }
        // Pre-create the CodedInputStream so that we can remove the size limit restriction
        // when parsing.
        cis.setSizeLimit(Integer.MAX_VALUE);

        try {
          return parseFrom(cis);
        } catch (InvalidProtocolBufferException ipbe) {
          throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence")
              .withCause(ipbe).asRuntimeException();
        }
      }

      /**
       * Reads exactly {@code size} bytes from {@code stream} into the calling thread's scratch
       * buffer, replacing the buffer if it was reclaimed or is too small.
       *
       * <p>The returned array is shared per thread and may be longer than {@code size}; it must
       * not be used after {@link #parse} has returned.
       *
       * @throws IOException if the underlying read fails
       * @throws RuntimeException if the stream ends before {@code size} bytes were read
       */
      private byte[] readFully(InputStream stream, int size) throws IOException {
        byte[] buf = bufs.get().get();
        if (buf == null || buf.length < size) {
          buf = new byte[size];
          bufs.set(new WeakReference<byte[]>(buf));
        }

        int position = 0;
        while (position < size) {
          int count = stream.read(buf, position, size - position);
          if (count == -1) {
            break;
          }
          position += count;
        }

        if (position != size) {
          throw new RuntimeException("size inaccurate: " + size + " != " + position);
        }
        return buf;
      }

      /**
       * Parses one message from {@code stream} using the global extension registry and verifies
       * that the input ended cleanly. On a bad trailing tag, the parsed message is attached to
       * the exception as its unfinished message before it is rethrown.
       */
      private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException {
        T message = parser.parseFrom(stream, globalRegistry);
        try {
          stream.checkLastTagWas(0);
          return message;
        } catch (InvalidProtocolBufferException e) {
          e.setUnfinishedMessage(message);
          throw e;
        }
      }
    };
  }

  /**
   * Produce a metadata marshaller for a protobuf type.
   *
   * @param instance a message of the target type; its parser is used to decode values
   * @return a marshaller that encodes with {@code toByteArray()} and decodes with the instance's
   *     parser and the current global extension registry, throwing
   *     {@link IllegalArgumentException} when the bytes are not a valid protobuf
   */
  public static <T extends MessageLite> Metadata.BinaryMarshaller<T> metadataMarshaller(
      final T instance) {
    return new Metadata.BinaryMarshaller<T>() {
      @Override
      public byte[] toBytes(T value) {
        return value.toByteArray();
      }

      @Override
      @SuppressWarnings("unchecked")
      public T parseBytes(byte[] serialized) {
        try {
          return (T) instance.getParserForType().parseFrom(serialized, globalRegistry);
        } catch (InvalidProtocolBufferException ipbe) {
          throw new IllegalArgumentException(ipbe);
        }
      }
    };
  }

  /** Not instantiable: this class only exposes static helpers. */
  private ProtoLiteUtils() {
  }
}