author     Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2022-09-23 10:13:36 +0000
committer  Android Build Coastguard Worker <android-build-coastguard-worker@google.com>  2022-09-23 10:13:36 +0000
commit     eb2b74757c9aed697fb27c5fe1f585e98dbe50b2 (patch)
tree       f98fc6a3f62fae52700dc6e6dc1e68e432a18f2f
parent     bef5a8f859ed8de1bef279aeff89a419d9cb55fe (diff)
parent     be767e5d685ec77e54ad9b2ba2f63839b78a1cd0 (diff)
Snap for 9098257 from be767e5d685ec77e54ad9b2ba2f63839b78a1cd0 to mainline-scheduling-release
Change-Id: Iba47a12031d5fff2aa172b6cd200e9a8964bdc8c
-rw-r--r--  tensorflow/lite/builtin_ops.h                                    |   1
-rw-r--r--  tensorflow/lite/c/builtin_op_data.h                              |   4
-rw-r--r--  tensorflow/lite/core/api/flatbuffer_conversions.cc               |  15
-rw-r--r--  tensorflow/lite/core/shims/builtin_ops_list.inc                  |   1
-rw-r--r--  tensorflow/lite/kernels/activations.cc                           |  55
-rw-r--r--  tensorflow/lite/kernels/activations_test.cc                      | 165
-rw-r--r--  tensorflow/lite/kernels/builtin_op_kernels.h                     |   1
-rw-r--r--  tensorflow/lite/kernels/internal/constants.h                     |  61
-rw-r--r--  tensorflow/lite/kernels/internal/reference/gelu.h                |  82
-rw-r--r--  tensorflow/lite/kernels/register.cc                              |   3
-rw-r--r--  tensorflow/lite/kernels/register_ref.cc                          |   4
-rw-r--r--  tensorflow/lite/schema/schema.fbs                                |   6
-rwxr-xr-x  tensorflow/lite/schema/schema_generated.h                        | 165
-rw-r--r--  tensorflow/lite/tools/optimize/operator_property.cc              |   5
-rw-r--r--  tensorflow/lite/tools/serialization/option_writer_generator.cc  |   2
-rw-r--r--  tensorflow/lite/tools/versioning/op_version.cc                   |   6
-rw-r--r--  tensorflow/lite/tools/versioning/op_version_test.cc              |  14
-rw-r--r--  tensorflow/lite/tools/versioning/runtime_version.cc              |   2
18 files changed, 583 insertions(+), 9 deletions(-)
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index 505cbdbcd12..1a3ec6b359e 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -173,6 +173,7 @@ typedef enum {
kTfLiteBuiltinReadVariable = 143,
kTfLiteBuiltinAssignVariable = 144,
kTfLiteBuiltinBroadcastArgs = 145,
+ kTfLiteBuiltinGelu = 150,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h
index ed5ac004cbd..8e49c91ad5f 100644
--- a/tensorflow/lite/c/builtin_op_data.h
+++ b/tensorflow/lite/c/builtin_op_data.h
@@ -502,6 +502,10 @@ typedef struct {
const char* shared_name;
} TfLiteVarHandleParams;
+typedef struct {
+ bool approximate;
+} TfLiteGeluParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
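
For orientation, here is a minimal sketch of how a kernel consumes this struct, assuming the standard TFLite convention that the parsed options land in node->builtin_data (the helper name below is hypothetical, not part of this change):

    #include "tensorflow/lite/c/builtin_op_data.h"
    #include "tensorflow/lite/c/common.h"

    // Hypothetical helper: reads the parsed GELU options off a prepared node.
    static bool GeluUsesApproximation(const TfLiteNode* node) {
      const auto* params =
          reinterpret_cast<const TfLiteGeluParams*>(node->builtin_data);
      return params != nullptr && params->approximate;
    }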
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index da714794a12..2897728123e 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -795,6 +795,21 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
+ case BuiltinOperator_GELU: {
+ auto params = safe_allocator.Allocate<TfLiteGeluParams>();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+ if (const auto* gelu_params = op->builtin_options_as_GeluOptions()) {
+ params->approximate = gelu_params->approximate();
+ }
+ *builtin_data = params.release();
+ return kTfLiteOk;
+ }
+ // Unsupported builtins.
+ case BuiltinOperator_RANDOM_STANDARD_NORMAL:
+ case BuiltinOperator_BUCKETIZE:
+ case BuiltinOperator_RANDOM_UNIFORM:
+ case BuiltinOperator_MULTINOMIAL:
+ return kTfLiteError;
// Below are the ops with no builtin_data structure.
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
diff --git a/tensorflow/lite/core/shims/builtin_ops_list.inc b/tensorflow/lite/core/shims/builtin_ops_list.inc
index b96e60afa6e..bb5110a3fc3 100644
--- a/tensorflow/lite/core/shims/builtin_ops_list.inc
+++ b/tensorflow/lite/core/shims/builtin_ops_list.inc
@@ -158,3 +158,4 @@ TFLITE_OP(Register_VAR_HANDLE)
TFLITE_OP(Register_READ_VARIABLE)
TFLITE_OP(Register_ASSIGN_VARIABLE)
TFLITE_OP(Register_BROADCAST_ARGS)
+TFLITE_OP(Register_GELU)
diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index b13a2dac992..f4ded17d8f6 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
+#include "tensorflow/lite/kernels/internal/reference/gelu.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
@@ -1501,6 +1502,53 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus GeluPrepare(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input;
+ TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+ auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data);
+
+ if (input->type == kTfLiteInt8) {
+ PopulateLookupTable<int8_t>(
+ data, input, output, reference_ops::GeluTransform(params->approximate));
+ } else if (input->type == kTfLiteUInt8) {
+ PopulateLookupTable<uint8_t>(
+ data, input, output, reference_ops::GeluTransform(params->approximate));
+ }
+ return GenericPrepare(context, node);
+}
+
+TfLiteStatus GeluEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data);
+ const TfLiteTensor* input;
+ TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+ TfLiteTensor* output;
+ TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ reference_ops::Gelu(GetTensorShape(input), GetTensorData<float>(input),
+ params->approximate, GetTensorShape(output),
+ GetTensorData<float>(output));
+ return kTfLiteOk;
+ }
+ case kTfLiteInt8:
+ case kTfLiteUInt8: {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+ EvalUsingLookupTable(data, input, output);
+ return kTfLiteOk;
+ }
+ default:
+ TF_LITE_KERNEL_LOG(
+ context, "Only float32, int8 and uint8 supported currently, got %s.",
+ TfLiteTypeGetName(input->type));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace activations
TfLiteRegistration* Register_ELU() {
@@ -1661,6 +1709,13 @@ TfLiteRegistration* Register_HARD_SWISH_REF() {
return &r;
}
+TfLiteRegistration* Register_GELU() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::GeluPrepare,
+ activations::GeluEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
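
GeluPrepare above leans on the PopulateLookupTable and EvalUsingLookupTable helpers already defined in activations.cc. As a standalone sketch of that lookup-table pattern (illustrative only, assuming affine int8 quantization; this is not the helper from this file):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <functional>

    // Sketch: precompute an activation over all 256 int8 codes by
    // dequantizing each code, applying the float transform once (e.g. the
    // functor returned by reference_ops::GeluTransform), and requantizing.
    void BuildInt8Table(float in_scale, int in_zero_point, float out_scale,
                        int out_zero_point,
                        const std::function<float(float)>& transform,
                        int8_t table[256]) {
      for (int code = -128; code < 128; ++code) {
        const float x = in_scale * (code - in_zero_point);  // dequantize
        const float y = transform(x);                       // e.g. GELU
        const int q =
            out_zero_point + static_cast<int>(std::round(y / out_scale));
        table[code + 128] = static_cast<int8_t>(std::clamp(q, -128, 127));
      }
    }

Evaluation then reduces to one table load per element, which is why the quantized path needs no per-element transcendental calls.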
diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc
index f629775dd4a..3bf96fadff3 100644
--- a/tensorflow/lite/kernels/activations_test.cc
+++ b/tensorflow/lite/kernels/activations_test.cc
@@ -2564,6 +2564,171 @@ TEST(FloatActivationsOpTest, LeakyRelu) {
}));
}
+
+class GeluOpModel : public SingleOpModel {
+ public:
+ GeluOpModel(const TensorData& input, bool approximate) {
+ input_ = AddInput(input);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions,
+ CreateGeluOptions(builder_, approximate).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class BaseGeluOpModel : public SingleOpModel {
+ public:
+ BaseGeluOpModel(const TensorData& input, bool approximate) {
+ input_ = AddInput(input);
+ approximate_ = approximate;
+ output_ = AddOutput({input.type, input.shape, input.min, input.max});
+ SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions,
+ CreateGeluOptions(builder_, approximate).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+
+ bool approximate_;
+ int output_;
+};
+
+// The FloatGeluOpModel class handles float input and output.
+class FloatGeluOpModel : public BaseGeluOpModel {
+ public:
+ using BaseGeluOpModel::BaseGeluOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// The QuantizedGeluOpModel class handles quantized input and output.
+class QuantizedGeluOpModel : public BaseGeluOpModel {
+ public:
+ using BaseGeluOpModel::BaseGeluOpModel;
+
+ template <typename T>
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<T>(input_, data);
+ }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ template <typename T>
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatActivationsOpTest, Gelu) {
+ FloatGeluOpModel m({TensorType_FLOAT32, {2, 3}}, /*approximate=*/false);
+
+ m.SetInput({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.0f, 0.841345f, 2.99595f, // Row 1
+ 0.841345f, -0.158655f, -0.0455003f, // Row 2
+ })));
+}
+
+TEST(FloatActivationsOpTest, GeluApproximate) {
+ FloatGeluOpModel m({TensorType_FLOAT32, {2, 3}}, /*approximate=*/true);
+
+ m.SetInput({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.0f, 0.841192f, 2.99636f, // Row 1
+ 0.841192f, -0.158808f, -0.0454023f, // Row 2
+ })));
+}
+
+TEST(QuantizedGeluOpTest, GeluInt8) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
+ QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax},
+ /*approximate=*/false);
+ m.SetInput<int8_t>({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 0.f, 0.84375f, 2.97656f, // Row 1
+ 0.84375f, -0.164062f, -0.046875f // Row 2
+ })));
+}
+
+TEST(QuantizedGeluOpTest, GeluInt8Approximate) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
+ QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax},
+ /*approximate=*/true);
+ m.SetInput<int8_t>({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 0.f, 0.84375f, 2.97656f, // Row 1
+ 0.84375f, -0.164062f, -0.046875f // Row 2
+ })));
+}
+
+TEST(QuantizedGeluOpTest, GeluUInt8) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
+ QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax},
+ /*approximate=*/false);
+ m.SetInput<uint8_t>({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 0.f, 0.84375f, 2.97656f, // Row 1
+ 0.84375f, -0.164062f, -0.046875f // Row 2
+ })));
+}
+
+TEST(QuantizedGeluOpTest, GeluUInt8Approximate) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
+ QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax},
+ /*approximate=*/true);
+ m.SetInput<uint8_t>({
+ 0.0f, 1.0f, 3.0f, // Row 1
+ 1.0f, -1.0f, -2.0f, // Row 2
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 0.f, 0.84375f, 2.97656f, // Row 1
+ 0.84375f, -0.164062f, -0.046875f // Row 2
+ })));
+}
+
+
INSTANTIATE_TEST_SUITE_P(
TanhOpTest, TanhOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kTanhKernelMap)));
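
The quantized expectations in the tests above follow directly from the chosen grid: with range [3 * kMin, 3 * kMax] the scale is 3/128 ≈ 0.0234375 and the zero point is 0, so the exact GELU(1.0) = 0.841345 snaps to step 36, i.e. 0.84375. A small sketch of that arithmetic:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Grid used by the int8/uint8 tests: [3 * -1, 3 * 127/128] in 255 steps.
      const float scale = (3.f * (127.f / 128.f) - 3.f * -1.f) / 255.f;  // 3/128
      const int q = static_cast<int>(std::round(0.841345f / scale));     // 36
      std::printf("scale=%f q=%d dequantized=%f\n", scale, q, q * scale);
      // Prints: scale=0.023438 q=36 dequantized=0.843750
    }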
diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h
index 85cc9b92a0d..834045621d2 100644
--- a/tensorflow/lite/kernels/builtin_op_kernels.h
+++ b/tensorflow/lite/kernels/builtin_op_kernels.h
@@ -74,6 +74,7 @@ TfLiteRegistration* Register_FLOOR_MOD();
TfLiteRegistration* Register_FULLY_CONNECTED();
TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_GATHER_ND();
+TfLiteRegistration* Register_GELU();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_HARD_SWISH();
diff --git a/tensorflow/lite/kernels/internal/constants.h b/tensorflow/lite/kernels/internal/constants.h
new file mode 100644
index 00000000000..aa8bd5f0860
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/constants.h
@@ -0,0 +1,61 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_
+
+// Maths constants.
+// The following macros are not always available on all platforms.
+// E.g. MSVC requires an additional compile flag to export them.
+#ifndef M_E
+#define M_E 2.7182818284590452354 /* e */
+#endif
+#ifndef M_LOG2E
+#define M_LOG2E 1.4426950408889634074 /* log_2 e */
+#endif
+#ifndef M_LOG10E
+#define M_LOG10E 0.43429448190325182765 /* log_10 e */
+#endif
+#ifndef M_LN2
+#define M_LN2 0.69314718055994530942 /* log_e 2 */
+#endif
+#ifndef M_LN10
+#define M_LN10 2.30258509299404568402 /* log_e 10 */
+#endif
+#ifndef M_PI
+#define M_PI 3.14159265358979323846 /* pi */
+#endif
+#ifndef M_PI_2
+#define M_PI_2 1.57079632679489661923 /* pi/2 */
+#endif
+#ifndef M_PI_4
+#define M_PI_4 0.78539816339744830962 /* pi/4 */
+#endif
+#ifndef M_1_PI
+#define M_1_PI 0.31830988618379067154 /* 1/pi */
+#endif
+#ifndef M_2_PI
+#define M_2_PI 0.63661977236758134308 /* 2/pi */
+#endif
+#ifndef M_2_SQRTPI
+#define M_2_SQRTPI 1.12837916709551257390 /* 2/sqrt(pi) */
+#endif
+#ifndef M_SQRT2
+#define M_SQRT2 1.41421356237309504880 /* sqrt(2) */
+#endif
+#ifndef M_SQRT1_2
+#define M_SQRT1_2 0.70710678118654752440 /* 1/sqrt(2) */
+#endif
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_CONSTANTS_H_
\ No newline at end of file
diff --git a/tensorflow/lite/kernels/internal/reference/gelu.h b/tensorflow/lite/kernels/internal/reference/gelu.h
new file mode 100644
index 00000000000..08e5a33241d
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/gelu.h
@@ -0,0 +1,82 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_
+
+#include <cmath>
+#include <functional>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/constants.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+namespace gelu_internal {
+
+constexpr float kSqrt2dPi = M_2_SQRTPI * M_SQRT1_2; // sqrt( 2 / pi )
+
+} // namespace gelu_internal
+
+// Plain implementation for GELU. Used for populating lookup table.
+inline std::function<float(float)> GeluTransform(bool approximate) {
+ if (approximate) {
+ return [](float in) {
+ // 0.5 * x * ( 1 + tanh( sqrt( 2 / pi ) * ( x + 0.044715 * x^3 ) ) )
+ return 0.5f * in *
+ (1.f + std::tanh(gelu_internal::kSqrt2dPi *
+ // Note: Avoid std::pow for integer exponents
+ // as it leads to much slower performance.
+ (in + 0.044715f * in * in * in)));
+ };
+ } else {
+ return [](float in) {
+ // 0.5 * x * ( 1 + erf( x / sqrt( 2 ) ) )
+ return 0.5f * in * (1.f + std::erf(in * M_SQRT1_2));
+ };
+ }
+}
+
+template <typename T>
+inline void Gelu(const RuntimeShape& input_shape, const T* input_data,
+ bool approximate, const RuntimeShape& output_shape,
+ T* output_data) {
+ auto matching_size = MatchingFlatSize(input_shape, output_shape);
+
+ for (int i = 0; i < matching_size; i++) {
+ const T in = input_data[i];
+ if (approximate) {
+ // 0.5 * x * ( 1 + tanh( sqrt( 2 / pi ) * ( x + 0.044715 * x^3 ) ) )
+ output_data[i] =
+ static_cast<T>(0.5) * in *
+ (static_cast<T>(1) +
+ std::tanh(static_cast<T>(gelu_internal::kSqrt2dPi) *
+ // Note: Avoid std::pow for integer exponents
+ // as it leads to much slower performance.
+ (in + static_cast<T>(0.044715) * in * in * in)));
+ } else {
+ // 0.5 * x * ( 1 + erf( x / sqrt( 2 ) ) )
+ output_data[i] =
+ static_cast<T>(0.5) * in *
+ (static_cast<T>(1) + std::erf(in * static_cast<T>(M_SQRT1_2)));
+ }
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_GELU_H_
\ No newline at end of file
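
A quick sanity check of the two formulas against the float test expectations earlier in this change (0.841345 exact, 0.841192 approximate at x = 1); a minimal sketch assuming only <cmath>:

    #include <cmath>
    #include <cstdio>

    int main() {
      const float x = 1.0f;
      // 0.5 * x * (1 + erf(x / sqrt(2)))                          -> ~0.841345
      const float exact = 0.5f * x * (1.f + std::erf(x * 0.7071068f));
      // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))   -> ~0.841192
      const float approx =
          0.5f * x *
          (1.f + std::tanh(0.7978846f * (x + 0.044715f * x * x * x)));
      std::printf("exact=%f approx=%f\n", exact, approx);
    }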
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 4f5fc7faf78..8e26f1d4849 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -329,6 +329,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_VAR_HANDLE, Register_VAR_HANDLE());
AddBuiltin(BuiltinOperator_READ_VARIABLE, Register_READ_VARIABLE());
AddBuiltin(BuiltinOperator_ASSIGN_VARIABLE, Register_ASSIGN_VARIABLE());
+ AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
+ /* min_version = */ 1,
+ /* max_version = */ 2);
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 889e003e404..d59eb9bdb0d 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -164,6 +164,7 @@ TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_COMPLEX_ABS();
TfLiteRegistration* Register_CONV_3D_TRANSPOSE_REF();
TfLiteRegistration* Register_BROADCAST_ARGS();
+TfLiteRegistration* Register_GELU();
namespace {
@@ -478,6 +479,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_CONV_3D_TRANSPOSE,
Register_CONV_3D_TRANSPOSE_REF());
AddBuiltin(BuiltinOperator_BROADCAST_ARGS, Register_BROADCAST_ARGS());
+ AddBuiltin(BuiltinOperator_GELU, Register_GELU(),
+ /* min_version = */ 1,
+ /* max_version = */ 2);
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index abd8db0012d..d2262de438f 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -379,6 +379,7 @@ enum BuiltinOperator : int32 {
READ_VARIABLE = 143,
ASSIGN_VARIABLE = 144,
BROADCAST_ARGS = 145,
+ GELU = 150,
}
// LINT.ThenChange(nnapi_linter/linter.proto)
@@ -497,6 +498,7 @@ union BuiltinOptions {
VarHandleOptions,
ReadVariableOptions,
AssignVariableOptions,
+ GeluOptions,
}
enum Padding : byte { SAME, VALID }
@@ -1082,6 +1084,10 @@ table ReadVariableOptions {
table AssignVariableOptions {
}
+table GeluOptions {
+ approximate: bool;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 77253a4e667..6f4e93aae64 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -385,6 +385,9 @@ struct ReadVariableOptionsT;
struct AssignVariableOptions;
struct AssignVariableOptionsT;
+struct GeluOptions;
+struct GeluOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -854,11 +857,16 @@ enum BuiltinOperator {
BuiltinOperator_READ_VARIABLE = 143,
BuiltinOperator_ASSIGN_VARIABLE = 144,
BuiltinOperator_BROADCAST_ARGS = 145,
+ BuiltinOperator_RANDOM_STANDARD_NORMAL = 146,
+ BuiltinOperator_BUCKETIZE = 147,
+ BuiltinOperator_RANDOM_UNIFORM = 148,
+ BuiltinOperator_MULTINOMIAL = 149,
+ BuiltinOperator_GELU = 150,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_BROADCAST_ARGS
+ BuiltinOperator_MAX = BuiltinOperator_GELU
};
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[146] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[151] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -1005,13 +1013,18 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[146] {
BuiltinOperator_VAR_HANDLE,
BuiltinOperator_READ_VARIABLE,
BuiltinOperator_ASSIGN_VARIABLE,
- BuiltinOperator_BROADCAST_ARGS
+ BuiltinOperator_BROADCAST_ARGS,
+ BuiltinOperator_RANDOM_STANDARD_NORMAL,
+ BuiltinOperator_BUCKETIZE,
+ BuiltinOperator_RANDOM_UNIFORM,
+ BuiltinOperator_MULTINOMIAL,
+ BuiltinOperator_GELU
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
- static const char * const names[147] = {
+ static const char * const names[152] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@@ -1158,13 +1171,18 @@ inline const char * const *EnumNamesBuiltinOperator() {
"READ_VARIABLE",
"ASSIGN_VARIABLE",
"BROADCAST_ARGS",
+ "RANDOM_STANDARD_NORMAL",
+ "BUCKETIZE",
+ "RANDOM_UNIFORM",
+ "MULTINOMIAL",
+ "GELU",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
- if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_ARGS)) return "";
+ if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_GELU)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@@ -1284,11 +1302,14 @@ enum BuiltinOptions {
BuiltinOptions_VarHandleOptions = 111,
BuiltinOptions_ReadVariableOptions = 112,
BuiltinOptions_AssignVariableOptions = 113,
+ BuiltinOptions_RandomOptions = 114,
+ BuiltinOptions_BucketizeOptions = 115,
+ BuiltinOptions_GeluOptions = 116,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_AssignVariableOptions
+ BuiltinOptions_MAX = BuiltinOptions_GeluOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[114] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[117] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -1403,13 +1424,16 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[114] {
BuiltinOptions_HashtableSizeOptions,
BuiltinOptions_VarHandleOptions,
BuiltinOptions_ReadVariableOptions,
- BuiltinOptions_AssignVariableOptions
+ BuiltinOptions_AssignVariableOptions,
+ BuiltinOptions_RandomOptions,
+ BuiltinOptions_BucketizeOptions,
+ BuiltinOptions_GeluOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
- static const char * const names[115] = {
+ static const char * const names[118] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@@ -1524,6 +1548,9 @@ inline const char * const *EnumNamesBuiltinOptions() {
"VarHandleOptions",
"ReadVariableOptions",
"AssignVariableOptions",
+ "RandomOptions",
+ "BucketizeOptions",
+ "GeluOptions",
nullptr
};
return names;
@@ -1991,6 +2018,10 @@ template<> struct BuiltinOptionsTraits<tflite::AssignVariableOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions;
};
+template<> struct BuiltinOptionsTraits<tflite::GeluOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -2927,6 +2958,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_AssignVariableOptions ?
reinterpret_cast<const tflite::AssignVariableOptionsT *>(value) : nullptr;
}
+ tflite::GeluOptionsT *AsGeluOptions() {
+ return type == BuiltinOptions_GeluOptions ?
+ reinterpret_cast<tflite::GeluOptionsT *>(value) : nullptr;
+ }
+ const tflite::GeluOptionsT *AsGeluOptions() const {
+ return type == BuiltinOptions_GeluOptions ?
+ reinterpret_cast<const tflite::GeluOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -10343,6 +10382,60 @@ inline flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions(
flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions(flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct GeluOptionsT : public flatbuffers::NativeTable {
+ typedef GeluOptions TableType;
+ bool approximate;
+ GeluOptionsT()
+ : approximate(false) {
+ }
+};
+
+struct GeluOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GeluOptionsT NativeTableType;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_APPROXIMATE = 4
+ };
+ bool approximate() const {
+ return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_APPROXIMATE) &&
+ verifier.EndTable();
+ }
+ GeluOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GeluOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GeluOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GeluOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_approximate(bool approximate) {
+ fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0);
+ }
+ explicit GeluOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GeluOptionsBuilder &operator=(const GeluOptionsBuilder &);
+ flatbuffers::Offset<GeluOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GeluOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GeluOptions> CreateGeluOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool approximate = false) {
+ GeluOptionsBuilder builder_(_fbb);
+ builder_.add_approximate(approximate);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
int8_t deprecated_builtin_code;
@@ -10832,6 +10925,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tflite::AssignVariableOptions *builtin_options_as_AssignVariableOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_AssignVariableOptions ? static_cast<const tflite::AssignVariableOptions *>(builtin_options()) : nullptr;
}
+ const tflite::GeluOptions *builtin_options_as_GeluOptions() const {
+ return builtin_options_type() == tflite::BuiltinOptions_GeluOptions ? static_cast<const tflite::GeluOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -11320,6 +11416,10 @@ template<> inline const tflite::AssignVariableOptions *Operator::builtin_options
return builtin_options_as_AssignVariableOptions();
}
+template<> inline const tflite::GeluOptions *Operator::builtin_options_as<tflite::GeluOptions>() const {
+ return builtin_options_as_GeluOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -15325,6 +15425,32 @@ inline flatbuffers::Offset<AssignVariableOptions> CreateAssignVariableOptions(fl
_fbb);
}
+inline GeluOptionsT *GeluOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GeluOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GeluOptions::UnPackTo(GeluOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = approximate(); _o->approximate = _e; }
+}
+
+inline flatbuffers::Offset<GeluOptions> GeluOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGeluOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GeluOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _approximate = _o->approximate;
+ return tflite::CreateGeluOptions(
+ _fbb,
+ _approximate);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -16252,6 +16378,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const tflite::AssignVariableOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_GeluOptions: {
+ auto ptr = reinterpret_cast<const tflite::GeluOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return true;
}
}
@@ -16722,6 +16852,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::AssignVariableOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_GeluOptions: {
+ auto ptr = reinterpret_cast<const tflite::GeluOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -17180,6 +17314,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const tflite::AssignVariableOptionsT *>(value);
return CreateAssignVariableOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_GeluOptions: {
+ auto ptr = reinterpret_cast<const tflite::GeluOptionsT *>(value);
+ return CreateGeluOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -17638,6 +17776,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new tflite::AssignVariableOptionsT(*reinterpret_cast<tflite::AssignVariableOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_GeluOptions: {
+ value = new tflite::GeluOptionsT(*reinterpret_cast<tflite::GeluOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -18210,6 +18352,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_GeluOptions: {
+ auto ptr = reinterpret_cast<tflite::GeluOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 1bcd0f27997..b18dcf78c7b 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -1041,6 +1041,11 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.outputs = {{0, {}}};
property.version = 1;
break;
+ case BuiltinOperator_GELU:
+ property.inputs = {{0, {}}};
+ property.outputs = {{0, {}}};
+ property.version = 2;
+ break;
default:
// No quantized implementation exists for this operation.
property.quantizable = false;
diff --git a/tensorflow/lite/tools/serialization/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc
index 8875e287609..f798cd710b8 100644
--- a/tensorflow/lite/tools/serialization/option_writer_generator.cc
+++ b/tensorflow/lite/tools/serialization/option_writer_generator.cc
@@ -41,6 +41,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteFakeQuantParams",
"TfLiteFullyConnectedParams",
"TfLiteGatherParams",
+ "TfLiteGeluParams",
"TfLiteIfParams",
"TfLiteL2NormParams",
"TfLiteLeakyReluParams",
@@ -205,6 +206,7 @@ class OpOptionData {
op_to_option_["IMAG"] = "";
op_to_option_["COMPLEX_ABS"] = "";
op_to_option_["BROADCAST_ARGS"] = "";
+ op_to_option_["GELU"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index 21891754ccd..96f2261ea63 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -785,6 +785,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 3;
}
return 2;
+ case BuiltinOperator_GELU:
+ if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
+ op_sig.inputs.at(0).type == kTfLiteUInt8) {
+ return 2;
+ }
+ return 1;
default:
return 1;
}
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index f605b976be8..47a92aa057d 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -1065,4 +1065,18 @@ TEST(OpVersionTest, VersioningBroadcastToTest) {
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
}
+
+TEST(OpVersionTest, VersioningGeluTest) {
+ OpSignature fake_op_sig;
+ fake_op_sig.op = BuiltinOperator_GELU;
+ fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+ fake_op_sig.op = BuiltinOperator_GELU;
+ fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig.op = BuiltinOperator_GELU;
+ fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8);
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+}
} // namespace tflite
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index b82f0a2a748..c602a9d9119 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -360,6 +360,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_READ_VARIABLE, 1}, "2.6.0"},
{{BuiltinOperator_ASSIGN_VARIABLE, 1}, "2.6.0"},
{{BuiltinOperator_BROADCAST_ARGS, 1}, "2.6.0"},
+ {{BuiltinOperator_GELU, 1}, "2.9.0"},
+ {{BuiltinOperator_GELU, 2}, "2.9.0"},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
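
Given the map entries above, both GELU op versions resolve to the same minimum runtime: the float kernel (version 1) and the quantized int8/uint8 kernel (version 2) ship in the same release. A brief usage sketch, assuming the declaration lives in tensorflow/lite/tools/versioning/runtime_version.h with the signature shown in the hunk header:

    #include <string>
    #include "tensorflow/lite/tools/versioning/runtime_version.h"

    void Demo() {
      // Both calls return "2.9.0" per the table in this change.
      const std::string v1 = tflite::FindMinimumRuntimeVersionForOp(
          tflite::BuiltinOperator_GELU, /*op_version=*/1);
      const std::string v2 = tflite::FindMinimumRuntimeVersionForOp(
          tflite::BuiltinOperator_GELU, /*op_version=*/2);
    }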