aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKarn Seth <karn@google.com>2021-01-13 19:17:50 +0000
committerKarn Seth <karn@google.com>2021-01-13 19:17:50 +0000
commit52c605f88b976d3ec386b09af0e72dec1e40d9a4 (patch)
treea3f26085acc24e53c31fb50ff3515ebffffb0dcc
parent884e999bde8f6c48e81c239eed95b7fcbaeb70ca (diff)
downloadprivate-join-and-compute-52c605f88b976d3ec386b09af0e72dec1e40d9a4.tar.gz
adds libraries for status testing, slight modifications to bignum
-rw-r--r--WORKSPACE16
-rw-r--r--crypto/big_num.cc21
-rw-r--r--crypto/big_num.h8
-rw-r--r--crypto/paillier.cc1
-rw-r--r--private_join_and_compute_rpc_impl.h2
-rw-r--r--util/BUILD110
-rw-r--r--util/file.cc76
-rw-r--r--util/file.h114
-rw-r--r--util/file_posix.cc167
-rw-r--r--util/file_test.cc148
-rw-r--r--util/file_test.proto23
-rw-r--r--util/proto_util.h52
-rw-r--r--util/proto_util_test.cc39
-rw-r--r--util/recordio.cc607
-rw-r--r--util/recordio.h273
-rw-r--r--util/recordio_test.cc508
-rw-r--r--util/status.inc2
-rw-r--r--util/status_macros.h15
-rw-r--r--util/status_matchers.h258
-rw-r--r--util/status_testing.h74
-rw-r--r--util/status_testing.inc17
21 files changed, 2517 insertions, 14 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 0382c17..33aa497 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -68,6 +68,22 @@ http_archive(
],
)
+# gtest.
+git_repository(
+ name = "com_github_google_googletest",
+ commit = "703bd9caab50b139428cea1aaff9974ebee5742e", # tag = "release-1.10.0"
+ remote = "https://github.com/google/googletest.git",
+ shallow_since = "1570114335 -0400",
+)
+
+# Protobuf
+git_repository(
+ name = "com_google_protobuf",
+ remote = "https://github.com/protocolbuffers/protobuf.git",
+ commit = "9647a7c2356a9529754c07235a2877ee676c2fd0",
+ shallow_since = "1609366209 -0800",
+)
+
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
# Includes boringssl, and other dependencies.
diff --git a/crypto/big_num.cc b/crypto/big_num.cc
index f886b0f..bba3f85 100644
--- a/crypto/big_num.cc
+++ b/crypto/big_num.cc
@@ -27,6 +27,23 @@
namespace private_join_and_compute {
+namespace {
+
+// Utility class for decimal string conversion.
+//
+// Takes ownership of the C string produced by BN_bn2dec and releases it with
+// OPENSSL_free when the holder is destroyed.
+class BnString {
+ public:
+  // Takes ownership of bn_char, which may be null if the conversion failed.
+  explicit BnString(char* bn_char) : bn_char_(bn_char) {}
+
+  // The held buffer is freed in the destructor, so a copy would free it twice.
+  BnString(const BnString&) = delete;
+  BnString& operator=(const BnString&) = delete;
+
+  ~BnString() { OPENSSL_free(bn_char_); }
+
+  // Returns a copy of the held text, or an empty string when no buffer is
+  // held (constructing a std::string from a null pointer is undefined).
+  std::string ToString() const {
+    return bn_char_ == nullptr ? std::string() : std::string(bn_char_);
+  }
+
+ private:
+  char* const bn_char_;
+};
+
+} // namespace
+
BigNum::BigNum(const BigNum& other)
: bn_(BignumPtr(CHECK_NOTNULL(BN_dup(other.bn_.get())))),
bn_ctx_(other.bn_ctx_) {}
@@ -90,6 +107,10 @@ StatusOr<uint64_t> BigNum::ToIntValue() const {
return val;
}
+std::string BigNum::ToDecimalString() const {
+  // Hold the buffer returned by BN_bn2dec in a BnString so that it is freed
+  // once its contents have been copied into the returned std::string.
+  BnString decimal_holder(BN_bn2dec(GetConstBignumPtr()));
+  return decimal_holder.ToString();
+}
+
int BigNum::BitLength() const { return BN_num_bits(bn_.get()); }
bool BigNum::IsPrime(double prime_error_probability) const {
diff --git a/crypto/big_num.h b/crypto/big_num.h
index da96885..702f0de 100644
--- a/crypto/big_num.h
+++ b/crypto/big_num.h
@@ -19,6 +19,7 @@
#include <stdint.h>
#include <memory>
+#include <ostream>
#include <string>
#include "crypto/openssl.inc"
@@ -57,6 +58,9 @@ class ABSL_MUST_USE_RESULT BigNum {
// error code if the value of *this is larger than 64 bits.
StatusOr<uint64_t> ToIntValue() const;
+ // Returns a string representation of the BigNum as a decimal number.
+ std::string ToDecimalString() const;
+
// Returns the bit length of this BigNum.
int BitLength() const;
@@ -239,6 +243,10 @@ inline BigNum& operator>>=(BigNum& a, int n) { return a = a >> n; }
inline BigNum& operator<<=(BigNum& a, int n) { return a = a << n; }
+// Streams the BigNum as "BigNum(<decimal value>)".
+inline std::ostream& operator<<(std::ostream& strm, const BigNum& a) {
+  strm << "BigNum(";
+  strm << a.ToDecimalString();
+  strm << ")";
+  return strm;
+}
+
} // namespace private_join_and_compute
#endif // CRYPTO_BIG_NUM_H_
diff --git a/crypto/paillier.cc b/crypto/paillier.cc
index 39400b2..dcaaf5e 100644
--- a/crypto/paillier.cc
+++ b/crypto/paillier.cc
@@ -416,7 +416,6 @@ StatusOr<BigNum> PublicPaillier::EncryptUsingGeneratorAndRand(
return c.ModMul(g_n_to_r, modulus_);
}
-
StatusOr<BigNum> PublicPaillier::EncryptWithRand(const BigNum& m,
const BigNum& r) const {
if (r.Gcd(n_) != ctx_->One()) {
diff --git a/private_join_and_compute_rpc_impl.h b/private_join_and_compute_rpc_impl.h
index a5c68a6..5ae4bde 100644
--- a/private_join_and_compute_rpc_impl.h
+++ b/private_join_and_compute_rpc_impl.h
@@ -16,8 +16,6 @@
#ifndef OPEN_SOURCE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_
#define OPEN_SOURCE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_
-#define GLOG_NO_ABBREVIATED_SEVERITIES
-#include "glog/logging.h"
#include "include/grpcpp/grpcpp.h"
#include "include/grpcpp/server_context.h"
#include "include/grpcpp/support/status.h"
diff --git a/util/BUILD b/util/BUILD
index bd5cfef..9243349 100644
--- a/util/BUILD
+++ b/util/BUILD
@@ -14,6 +14,8 @@
# Build file for util folder in open-source Private Join and Compute.
+load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
+
package(
default_visibility = ["//visibility:public"],
features = [
@@ -23,13 +25,12 @@ package(
)
cc_library(
- name = "status",
- srcs = glob(
- ["*.cc"],
- ),
- hdrs = glob(["*.h"]),
+ name = "status_includes",
+ hdrs = [
+ "status.inc",
+ "status_macros.h",
+ ],
deps = [
- "@com_github_glog_glog//:glog",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_protobuf//:protobuf_lite",
@@ -37,12 +38,101 @@ cc_library(
)
cc_library(
- name = "status_includes",
- hdrs = ["status.inc"],
+ name = "status_testing_includes",
+ hdrs = [
+ "status_matchers.h",
+ "status_testing.h",
+ "status_testing.inc",
+ ],
+ deps = [
+ ":status_includes",
+ "@com_github_google_googletest//:gtest",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "file",
+ srcs = [
+ "file.cc",
+ "file_posix.cc",
+ ],
+ hdrs = [
+ "file.h",
+ ],
+ deps = [
+ ":status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "file_test",
+ size = "small",
+ srcs = [
+ "file_test.cc",
+ ],
+ deps = [
+ ":file",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+grpc_proto_library(
+ name = "file_test_proto",
+ srcs = ["file_test.proto"],
+)
+
+cc_library(
+ name = "proto_util",
+ hdrs = ["proto_util.h"],
deps = [
- ":status",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf_lite",
+ ],
+)
+
+cc_test(
+ name = "proto_util_test",
+ size = "medium",
+ srcs = ["proto_util_test.cc"],
+ deps = [
+ ":file_test_proto",
+ ":proto_util",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "recordio",
+ srcs = [
+ "recordio.cc",
+ ],
+ hdrs = ["recordio.h"],
+ deps = [
+ ":file",
+ ":status_includes",
+ "@com_github_glog_glog//:glog",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
- "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
"@com_google_protobuf//:protobuf_lite",
],
)
+
+cc_test(
+ name = "recordio_test",
+ srcs = ["recordio_test.cc"],
+ deps = [
+ ":file_test_proto",
+ ":proto_util",
+ ":recordio",
+ ":status_includes",
+ ":status_testing_includes",
+ "//crypto:bn_util",
+ "@com_github_google_googletest//:gtest_main",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/util/file.cc b/util/file.cc
new file mode 100644
index 0000000..dae55be
--- /dev/null
+++ b/util/file.cc
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common implementations.
+
+#include "util/file.h"
+
+#include <sstream>
+
+namespace private_join_and_compute {
+namespace internal {
+namespace {
+
+// Returns true if the path is non-empty and begins with '/'.
+bool IsAbsolutePath(absl::string_view path) {
+  if (path.empty()) {
+    return false;
+  }
+  return path.front() == '/';
+}
+
+// Returns true if the path is non-empty and its last character is '/'.
+bool EndsWithSlash(absl::string_view path) {
+  if (path.empty()) {
+    return false;
+  }
+  return path.back() == '/';
+}
+
+} // namespace
+
+// Joins the given paths into a single path.
+//
+// The first argument, if non-empty, is copied verbatim and terminated with a
+// '/'. Every later argument has at most one leading and one trailing '/'
+// removed and is then appended, with a '.' between consecutive appended
+// components. Empty arguments, and later arguments that are empty after
+// stripping, contribute nothing, so they can never leave a dangling or
+// leading '.' in the result.
+std::string JoinPathImpl(std::initializer_list<std::string> paths) {
+  std::string joined_path;
+  // True once a component has been appended after the leading directory, i.e.
+  // the next appended component must be preceded by a dot.
+  bool needs_dot = false;
+
+  for (auto it = paths.begin(); it != paths.end(); ++it) {
+    const std::string& path = *it;
+    if (path.empty()) {
+      continue;
+    }
+
+    if (it == paths.begin()) {
+      joined_path += path;
+      if (!EndsWithSlash(path)) {
+        joined_path += "/";
+      }
+      continue;
+    }
+
+    // Strip at most one leading and one trailing slash.
+    const std::string::size_type begin = IsAbsolutePath(path) ? 1 : 0;
+    std::string::size_type end = path.size();
+    if (end > begin && EndsWithSlash(path)) {
+      --end;
+    }
+    if (begin == end) {
+      // Nothing left after stripping (e.g. "/"), so there is nothing to join.
+      continue;
+    }
+
+    if (needs_dot) {
+      joined_path += ".";
+    }
+    joined_path.append(path, begin, end - begin);
+    needs_dot = true;
+  }
+  return joined_path;
+}
+
+} // namespace internal
+} // namespace private_join_and_compute
diff --git a/util/file.h b/util/file.h
new file mode 100644
index 0000000..cb4911d
--- /dev/null
+++ b/util/file.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// copybara:strip_begin(internal comment)
+// Abstract File class for different file systems. The implementation in
+// file_posix.cc is used on NaCl, WASM and Open source, and the implementation
+// in file_google3.cc is used on standard google3.
+// copybara:strip_end
+
+#ifndef INTERNAL_UTIL_FILE_H_
+#define INTERNAL_UTIL_FILE_H_
+
+#include <string>
+
+#include "util/status.inc"
+
+namespace private_join_and_compute {
+
+// Renames a file. Overwrites the new file if it exists.
+// Returns Status::OK for success.
+// Error code in case of an error depends on the underlying implementation.
+Status RenameFile(absl::string_view from, absl::string_view to);
+
+// Deletes a file.
+// Returns Status::OK for success.
+// Error code in case of an error depends on the underlying implementation.
+Status DeleteFile(absl::string_view file_name);
+
+class File {
+ public:
+ virtual ~File() = default;
+
+ // Opens the file_name for file operations applicable based on mode.
+ // Returns Status::OK for success.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Open(absl::string_view file_name, absl::string_view mode) = 0;
+
+ // Closes the opened file. Must be called after opening a file.
+ // Returns Status::OK for success.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Close() = 0;
+
+ // Returns true if there are more data in the file to be read.
+ // Returns a status instead in case of an io error in determining if there is
+ // more data.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Returns a data string of size length from reading file if successful.
+ // Returns a status in case of an error.
+ // This would also return an error status if the read data size is less than
+ // the length since it indicates file corruption.
+ virtual StatusOr<std::string> Read(size_t length) = 0;
+
+ // Returns a line as string from the file without the trailing '\n' (or "\r\n"
+ // in the case of Windows).
+ //
+ // Returns a status in case of an error.
+ virtual StatusOr<std::string> ReadLine() = 0;
+
+ // Writes the given content of size length into the file.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Write(absl::string_view content, size_t length) = 0;
+
+ // Returns a File object depending on the linked implementation.
+ // Caller takes the ownership.
+ static File* GetFile();
+
+ protected:
+ File() = default;
+};
+
+namespace internal {
+std::string JoinPathImpl(std::initializer_list<std::string> paths);
+} // namespace internal
+
+// Joins multiple paths together such that only the first argument directory
+// structure is represented. A dot as a separator is added for other arguments.
+//
+// Arguments | JoinPath |
+// ---------------------------+---------------------+
+// '/foo', 'bar' | /foo/bar |
+// '/foo/', 'bar' | /foo/bar |
+// '/foo', '/bar' | /foo/bar |
+// '/foo', '/bar', '/baz' | /foo/bar.baz |
+//
+// All paths will be treated as relative paths, regardless of whether or not
+// they start with a leading '/'. That is, all paths will be concatenated
+// together, with the appropriate path separator inserted in between.
+// After the first path, all paths will be joined with a dot instead of the path
+// separator so that there is no level of directory after the first argument.
+// Arguments must be convertible to string.
+//
+// Usage:
+// std::string path = JoinPath("/tmp", dirname, filename);
+template <typename... T>
+std::string JoinPath(const T&... args) {
+ return internal::JoinPathImpl({args...});
+}
+
+} // namespace private_join_and_compute
+
+#endif // INTERNAL_UTIL_FILE_H_
diff --git a/util/file_posix.cc b/util/file_posix.cc
new file mode 100644
index 0000000..70806ff
--- /dev/null
+++ b/util/file_posix.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "util/file.h"
+#include "util/status.inc"
+#include "absl/strings/str_cat.h"
+
+namespace private_join_and_compute {
+namespace {
+
+// File implementation on top of the C stdio API (fopen, fread, fwrite, ...).
+//
+// An instance manages at most one open stream at a time: Open() fails while a
+// stream is open, and every other operation fails while none is open. The
+// destructor closes a stream that is still open and ignores the result.
+class PosixFile : public File {
+ public:
+  PosixFile() : File(), f_(nullptr), current_fname_() {}
+
+  // The instance owns the FILE* handle; a copy would close it twice.
+  PosixFile(const PosixFile&) = delete;
+  PosixFile& operator=(const PosixFile&) = delete;
+
+  ~PosixFile() override {
+    if (f_) Close().IgnoreError();
+  }
+
+  // Opens file_name with the given fopen mode string.
+  // Returns an internal error if a file is already open on this object, and
+  // a not-found error if fopen fails.
+  Status Open(absl::string_view file_name, absl::string_view mode) final {
+    if (nullptr != f_) {
+      return InternalError(
+          absl::StrCat("Open failed:", "File with name ", current_fname_,
+                       " has already been opened, close it first."));
+    }
+    // absl::string_view is not guaranteed to be NUL-terminated, but fopen
+    // expects C strings, so make owned copies first.
+    const std::string file_name_str(file_name);
+    const std::string mode_str(mode);
+    f_ = fopen(file_name_str.c_str(), mode_str.c_str());
+    if (nullptr == f_) {
+      return absl::NotFoundError(
+          absl::StrCat("Open failed:", "Error opening file ", file_name));
+    }
+    current_fname_ = file_name_str;
+    return OkStatus();
+  }
+
+  // Closes the currently opened file.
+  // Returns an internal error if no file is open or if fclose fails. The
+  // object is left without an open file in both cases.
+  Status Close() final {
+    if (nullptr == f_) {
+      return InternalError(
+          absl::StrCat("Close failed:", "There is no opened file."));
+    }
+    // The stream must not be used again after fclose, whether or not the call
+    // succeeds, so drop the handle before checking the result. Otherwise the
+    // destructor would close it a second time.
+    FILE* const stream = f_;
+    f_ = nullptr;
+    if (fclose(stream)) {
+      return InternalError(
+          absl::StrCat("Close failed:", "Error closing file ", current_fname_));
+    }
+    return OkStatus();
+  }
+
+  // Returns true if at least one more character can be read from the file.
+  // Peeks one character and pushes it back, so the read position is unchanged.
+  StatusOr<bool> HasMore() final {
+    if (nullptr == f_) {
+      return InternalError(
+          absl::StrCat("HasMore failed:", "There is no opened file."));
+    }
+    if (feof(f_)) return false;
+    if (ferror(f_)) {
+      return InternalError(absl::StrCat(
+          "HasMore failed:", "Error indicator has been set for file ",
+          current_fname_));
+    }
+    int c = getc(f_);
+    if (ferror(f_)) {
+      return InternalError(absl::StrCat(
+          "HasMore failed:", "Error reading a single character from the file ",
+          current_fname_));
+    }
+    if (ungetc(c, f_) != c) {
+      return InternalError(absl::StrCat(
+          "HasMore failed:", "Error putting back the peeked character ",
+          "into the file ", current_fname_));
+    }
+    return c != EOF;
+  }
+
+  // Reads exactly length bytes.
+  // Returns an internal error if no file is open or if fewer than length
+  // bytes could be read.
+  StatusOr<std::string> Read(size_t length) final {
+    if (nullptr == f_) {
+      return InternalError(
+          absl::StrCat("Read failed:", "There is no opened file."));
+    }
+    // Read straight into the string that is returned.
+    std::string data(length, '\0');
+    if (fread(&data[0], 1, length, f_) != length) {
+      return InternalError(absl::StrCat(
+          "Read failed:", "Error reading the file ", current_fname_));
+    }
+    return data;
+  }
+
+  // Reads one line and returns it without the trailing "\n" or "\r\n".
+  // Lines longer than the internal buffer are assembled from several reads.
+  // Returns an internal error if no file is open, if nothing could be read,
+  // or if the stream's error indicator is set.
+  StatusOr<std::string> ReadLine() final {
+    if (nullptr == f_) {
+      return InternalError(
+          absl::StrCat("ReadLine failed:", "There is no opened file."));
+    }
+    std::string content;
+    bool read_any = false;
+    // fgets stops after LINE_MAX - 1 characters, so keep reading until the
+    // chunk ends with a newline or the input is exhausted.
+    while (fgets(buffer_, LINE_MAX, f_) != nullptr) {
+      read_any = true;
+      content.append(buffer_);
+      if (!content.empty() && content.back() == '\n') {
+        break;
+      }
+    }
+    if (!read_any || ferror(f_)) {
+      return InternalError(
+          absl::StrCat("ReadLine failed:", "Error reading line from the file ",
+                       current_fname_));
+    }
+    // Remove trailing '\n' if present.
+    if (!content.empty() && content.back() == '\n') {
+      content.pop_back();
+      // Remove trailing '\r' if present (e.g. on Windows)
+      if (!content.empty() && content.back() == '\r') {
+        content.pop_back();
+      }
+    }
+    return content;
+  }
+
+  // Writes the first length bytes of content.
+  // Returns an invalid-argument error if length exceeds content.size(), and
+  // an internal error if no file is open or fewer than length bytes were
+  // written.
+  Status Write(absl::string_view content, size_t length) final {
+    if (nullptr == f_) {
+      return InternalError(
+          absl::StrCat("Write failed:", "There is no opened file."));
+    }
+    // Writing more bytes than the view holds would read out of bounds.
+    if (length > content.size()) {
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Write failed:", "Requested length ", length,
+          " exceeds the content size ", content.size()));
+    }
+    if (fwrite(content.data(), 1, length, f_) != length) {
+      return InternalError(absl::StrCat(
+          "Write failed:", "Error writing the given data into the file ",
+          current_fname_));
+    }
+    return OkStatus();
+  }
+
+ private:
+  FILE* f_;                    // Open stream, or nullptr when nothing is open.
+  std::string current_fname_;  // Name passed to the last successful Open().
+  char buffer_[LINE_MAX];      // Scratch space for ReadLine().
+};
+
+} // namespace
+
+// Creates a new PosixFile on the heap and returns it as a File*.
+File* File::GetFile() {
+  File* const posix_file = new PosixFile();
+  return posix_file;
+}
+
+// Renames the file `from` to `to` using rename().
+// Returns an internal error if rename() reports a failure.
+Status RenameFile(absl::string_view from, absl::string_view to) {
+  // absl::string_view is not guaranteed to be NUL-terminated, but rename
+  // expects C strings, so make owned copies first.
+  const std::string from_str(from);
+  const std::string to_str(to);
+  if (0 != rename(from_str.c_str(), to_str.c_str())) {
+    return InternalError(absl::StrCat(
+        "RenameFile failed:", "Cannot rename file, ", from, " to file, ", to));
+  }
+  return OkStatus();
+}
+
+// Deletes the file file_name using remove().
+// Returns an internal error if remove() reports a failure.
+Status DeleteFile(absl::string_view file_name) {
+  // absl::string_view is not guaranteed to be NUL-terminated, but remove
+  // expects a C string, so make an owned copy first.
+  const std::string file_name_str(file_name);
+  if (0 != remove(file_name_str.c_str())) {
+    return InternalError(
+        absl::StrCat("DeleteFile failed:", "Cannot delete file, ", file_name));
+  }
+  return OkStatus();
+}
+
+} // namespace private_join_and_compute
diff --git a/util/file_test.cc b/util/file_test.cc
new file mode 100644
index 0000000..291f3c9
--- /dev/null
+++ b/util/file_test.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/file.h"
+
+#include "util/status.inc"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace private_join_and_compute {
+namespace {
+
+// Checks that status_or is OK and that it holds expected_value.
+//
+// The OK check is fatal for this helper: if it fails, the helper returns
+// before value() is called on a StatusOr that holds no value.
+template <typename T1, typename T2>
+void AssertOkAndHolds(const T1& expected_value, const StatusOr<T2>& status_or) {
+  ASSERT_TRUE(status_or.ok()) << status_or.status();
+  EXPECT_EQ(expected_value, status_or.value());
+}
+
+class FileTest : public testing::Test {
+ public:
+ FileTest() : testing::Test(), f_(File::GetFile()) {}
+
+ std::unique_ptr<File> f_;
+};
+
+TEST_F(FileTest, WriteDataThenReadTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ EXPECT_TRUE(f_->Write("water", 4).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "rb").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("wat", f_->Read(3));
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("e", f_->Read(1));
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, ReadLineTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ EXPECT_TRUE(f_->Write("Line1\nLine2\n\n", 13).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line1", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line2", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("", f_->ReadLine());
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, CannotOpenFileIfAnotherFileIsAlreadyOpened) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp1.txt", "w").ok());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFile) {
+ EXPECT_FALSE(f_->Close().ok());
+ EXPECT_FALSE(f_->HasMore().ok());
+ EXPECT_FALSE(f_->Read(1).ok());
+ EXPECT_FALSE(f_->ReadLine().ok());
+ EXPECT_FALSE(f_->Write("w", 1).ok());
+}
+
+TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFileAfterClosing) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_FALSE(f_->Close().ok());
+ EXPECT_FALSE(f_->HasMore().ok());
+ EXPECT_FALSE(f_->Read(1).ok());
+ EXPECT_FALSE(f_->ReadLine().ok());
+ EXPECT_FALSE(f_->Write("w", 1).ok());
+}
+
+TEST_F(FileTest, TestRename) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Write("water", 5).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(RenameFile(testing::TempDir() + "/tmp.txt",
+ testing::TempDir() + "/tmp1.txt")
+ .ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp1.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("water", f_->Read(5));
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, TestDelete) {
+ // Create file and delete it.
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Write("water", 5).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(DeleteFile(testing::TempDir() + "/tmp.txt").ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+
+ // Try to delete nonexistent file.
+ EXPECT_FALSE(DeleteFile(testing::TempDir() + "/tmp2.txt").ok());
+}
+
+TEST_F(FileTest, JoinPathWithMultipleArgs) {
+ std::string ret = JoinPath("/tmp", "foo", "bar/", "/baz/");
+ EXPECT_EQ("/tmp/foo.bar.baz", ret);
+}
+
+TEST_F(FileTest, JoinPathWithMultipleArgsStartingWithEndSlashDir) {
+ std::string ret = JoinPath("/tmp/", "foo", "bar/", "/baz/");
+ EXPECT_EQ("/tmp/foo.bar.baz", ret);
+}
+
+TEST_F(FileTest, ReadLineWithCarriageReturnsTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ std::string file_string = "Line1\nLine2\r\nLine3\r\nLine4\n\n";
+ EXPECT_TRUE(f_->Write(file_string, file_string.size()).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line1", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line2", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line3", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line4", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("", f_->ReadLine());
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/util/file_test.proto b/util/file_test.proto
new file mode 100644
index 0000000..3bf8e96
--- /dev/null
+++ b/util/file_test.proto
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute.testing;
+
+message TestProto {
+ optional bytes record = 1;
+ optional bytes dummy = 2;
+}
diff --git a/util/proto_util.h b/util/proto_util.h
new file mode 100644
index 0000000..8239751
--- /dev/null
+++ b/util/proto_util.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Protocol buffer related static utility functions.
+
+#ifndef INTERNAL_UTIL_PROTO_UTIL_H_
+#define INTERNAL_UTIL_PROTO_UTIL_H_
+
+#include <sstream>
+#include <string>
+
+#include "src/google/protobuf/message_lite.h"
+#include "absl/strings/string_view.h"
+
+namespace private_join_and_compute {
+
+class ProtoUtils {
+ public:
+ template <typename ProtoType>
+ static ProtoType FromString(absl::string_view raw_data);
+
+ static std::string ToString(const google::protobuf::MessageLite& record);
+};
+
+template <typename ProtoType>  // ProtoType must provide ParseFromArray(data, size).
+inline ProtoType ProtoUtils::FromString(absl::string_view raw_data) {
+ ProtoType record;
+ record.ParseFromArray(raw_data.data(), raw_data.size());  // NOTE(review): the parse result is not checked.
+ return record;  // Returned as-is whether or not parsing succeeded.
+}
+
+// Returns the serialized wire-format bytes of record.
+inline std::string ProtoUtils::ToString(const google::protobuf::MessageLite& record) {
+  // Serialize directly into a string instead of going through an
+  // std::ostringstream and copying its buffer out again.
+  return record.SerializeAsString();
+}
+
+} // namespace private_join_and_compute
+
+#endif // INTERNAL_UTIL_PROTO_UTIL_H_
diff --git a/util/proto_util_test.cc b/util/proto_util_test.cc
new file mode 100644
index 0000000..ecb486d
--- /dev/null
+++ b/util/proto_util_test.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/proto_util.h"
+
+#include "util/file_test.pb.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace private_join_and_compute {
+
+namespace {
+using testing::TestProto;
+
+TEST(ProtoUtilsTest, ConvertsToAndFrom) {
+ TestProto expected_test_proto;
+ expected_test_proto.set_record("data");
+ expected_test_proto.set_dummy("dummy");
+ std::string serialized = ProtoUtils::ToString(expected_test_proto);
+ TestProto actual_test_proto = ProtoUtils::FromString<TestProto>(serialized);
+ EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record());
+ EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy());
+}
+
+} // namespace
+
+} // namespace private_join_and_compute
diff --git a/util/recordio.cc b/util/recordio.cc
new file mode 100644
index 0000000..d509632
--- /dev/null
+++ b/util/recordio.cc
@@ -0,0 +1,607 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/recordio.h"
+
+#include <algorithm>
+#include <functional>
+#include <list>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#define GLOG_NO_ABBREVIATED_SEVERITIES
+#include "glog/logging.h"
+#include "src/google/protobuf/io/coded_stream.h"
+#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
+#include "util/status.inc"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Max. size of a Varint32 (from proto references).
+const uint32_t kMaxVarint32Size = 5;
+
+// Tries to read a Varint32 from the front of the given file, one byte at a
+// time, and returns the decoded value.
+//
+// Returns the underlying File error if HasMore() or Read() fails, and
+// InvalidArgumentError if no terminating byte is found within
+// kMaxVarint32Size bytes or if the bytes that were read do not form a complete
+// varint (for example because the file has no more data).
+StatusOr<uint32_t> ExtractVarint32(File* file) {
+  // Keep reading a single byte until one is found whose top (continuation) bit
+  // is 0, or until the file reports that it has no more data.
+  std::string bytes_read;
+
+  size_t current_byte = 0;
+  ASSIGN_OR_RETURN(auto has_more, file->HasMore());
+  while (current_byte < kMaxVarint32Size && has_more) {
+    auto maybe_last_byte = file->Read(1);
+    if (!maybe_last_byte.ok()) {
+      return maybe_last_byte.status();
+    }
+
+    bytes_read += maybe_last_byte.value();
+    if (!(bytes_read.data()[current_byte] & 0x80)) {
+      break;
+    }
+    current_byte++;
+    // If we read the max number of bytes and never found a "terminating" byte,
+    // the input is not a valid Varint32.
+    if (current_byte >= kMaxVarint32Size) {
+      return InvalidArgumentError(
+          "ExtractVarint32: Failed to extract a Varint after reading max "
+          "number "
+          "of bytes.");
+    }
+    ASSIGN_OR_RETURN(has_more, file->HasMore());
+  }
+
+  google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(),
+                                                bytes_read.size());
+  google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream);
+  // Initialize the result and check the decode: when the file is empty or ends
+  // in the middle of a varint, ReadVarint32 fails and would otherwise leave
+  // `result` uninitialized.
+  uint32_t result = 0;
+  if (!codedInputStream.ReadVarint32(&result)) {
+    return InvalidArgumentError(
+        "ExtractVarint32: Failed to decode a Varint from the bytes read.");
+  }
+
+  return result;
+}
+
+// RecordReader for size-delimited binary records: each record is read as a
+// Varint32 size followed by that many bytes of payload.
+class RecordReaderImpl : public RecordReader {
+ public:
+  // Stores `file` in a std::unique_ptr, so this reader owns it from here on.
+  explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {}
+
+  // Opens `filename` on the underlying file in "r" mode.
+  Status Open(absl::string_view filename) final {
+    return in_->Open(filename, "r");
+  }
+
+  Status Close() final { return in_->Close(); }
+
+  // Forwards to the underlying file; an error is logged and still returned.
+  StatusOr<bool> HasMore() final {
+    auto has_more_or = in_->HasMore();
+    if (has_more_or.ok()) {
+      return has_more_or;
+    }
+    LOG(ERROR) << has_more_or.status();
+    return has_more_or;
+  }
+
+  // Clears `raw_data`, then reads the size prefix and the payload that follows
+  // it. On failure the error is logged and returned, and `raw_data` stays
+  // empty.
+  Status Read(std::string* raw_data) final {
+    raw_data->erase();
+
+    const auto size_or = ExtractVarint32(in_.get());
+    if (!size_or.ok()) {
+      LOG(ERROR) << "RecordReader::Read: Couldn't read record size: "
+                 << size_or.status();
+      return size_or.status();
+    }
+    const uint32_t payload_size = size_or.value();
+
+    const auto payload_or = in_->Read(payload_size);
+    if (!payload_or.ok()) {
+      LOG(ERROR) << payload_or.status();
+      return payload_or.status();
+    }
+
+    raw_data->append(payload_or.value());
+    return OkStatus();
+  }
+
+ private:
+  std::unique_ptr<File> in_;
+};
+
+// RecordReader that treats every line of the underlying file as one record.
+class LineReader : public RecordReader {
+ public:
+  // Stores `file` in a std::unique_ptr, so this reader owns it from here on.
+  explicit LineReader(File* file) : RecordReader(), in_(file) {}
+
+  // Opens `filename` on the underlying file in "r" mode.
+  Status Open(absl::string_view filename) final {
+    return in_->Open(filename, "r");
+  }
+
+  Status Close() final { return in_->Close(); }
+
+  StatusOr<bool> HasMore() final { return in_->HasMore(); }
+
+  // Clears `line` and fills it with the result of File::ReadLine(). On failure
+  // the error is logged and returned, and `line` stays empty.
+  Status Read(std::string* line) final {
+    line->erase();
+
+    const auto line_or = in_->ReadLine();
+    if (line_or.ok()) {
+      line->append(line_or.value());
+      return OkStatus();
+    }
+
+    LOG(ERROR) << line_or.status();
+    return line_or.status();
+  }
+
+ private:
+  std::unique_ptr<File> in_;
+};
+
+// MultiSortedReader implementation that merges several individually sorted
+// files. It keeps at most one pending record per file in a min-heap ordered by
+// key, and always hands out the record with the smallest key.
+template <typename T>
+class MultiSortedReaderImpl : public MultiSortedReader<T> {
+ public:
+  // `get_reader` is invoked once per file in Open() to create its RecordReader;
+  // the returned pointer is owned by this object. `default_key` is the key
+  // function used by the single-argument Open(); it may be null, in which case
+  // that overload fails.
+  explicit MultiSortedReaderImpl(
+      const std::function<RecordReader*()>& get_reader,
+      std::unique_ptr<std::function<T(absl::string_view)>> default_key =
+          nullptr)
+      : MultiSortedReader<T>(),
+        get_reader_(get_reader),
+        default_key_(std::move(default_key)),
+        key_(nullptr) {}
+
+  // Opens `filenames` using the default key function. Returns
+  // InvalidArgumentError if no default key was supplied.
+  Status Open(const std::vector<std::string>& filenames) override {
+    if (default_key_ == nullptr) {
+      return InvalidArgumentError("The sorting key is null.");
+    }
+    return Open(filenames, *default_key_);
+  }
+
+  // Opens one reader per entry of `filenames`, comparing records by `key`.
+  // Returns InternalError if readers from a previous Open() are still held.
+  // If any file fails to open, every reader opened so far is closed and
+  // released, and the open error is returned.
+  Status Open(const std::vector<std::string>& filenames,
+              const std::function<T(absl::string_view)>& key) override {
+    if (!readers_.empty()) {
+      return InternalError("There are files not closed, call Close() first.");
+    }
+    key_ = absl::make_unique<std::function<T(absl::string_view)>>(key);
+    for (const std::string& filename : filenames) {
+      std::unique_ptr<RecordReader> reader(get_reader_());
+      const Status open_status = reader->Open(filename);
+      if (!open_status.ok()) {
+        // The reader that failed to open was never added to readers_, so it is
+        // simply dropped. Close the ones that did open, most recent first, and
+        // release all of them so that readers_ ends up empty and a later
+        // Open() is not rejected because of stale entries.
+        while (!readers_.empty()) {
+          const Status close_status = readers_.back()->Close();
+          if (!close_status.ok()) {
+            LOG(ERROR) << "Error closing file " << close_status;
+          }
+          readers_.pop_back();
+        }
+        return open_status;
+      }
+      readers_.push_back(std::move(reader));
+    }
+    return OkStatus();
+  }
+
+  // Closes the readers in order, stopping at the first one that fails and
+  // returning its status. The readers and any pending records are released
+  // only when every reader closed successfully.
+  Status Close() override {
+    Status status = OkStatus();
+    bool ret_val =
+        std::all_of(readers_.begin(), readers_.end(),
+                    [&status](std::unique_ptr<RecordReader>& reader) {
+                      Status close_status = reader->Close();
+                      if (!close_status.ok()) {
+                        status = close_status;
+                        return false;
+                      } else {
+                        return true;
+                      }
+                    });
+    if (ret_val) {
+      readers_ = std::vector<std::unique_ptr<RecordReader>>();
+      min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>,
+                                      HeapDataGreater>();
+    }
+    return status;
+  }
+
+  // Returns true if a record is pending in the heap or any reader reports more
+  // data. If no reader has more data and at least one reader returned an
+  // error, the last such error is returned instead of false.
+  StatusOr<bool> HasMore() override {
+    if (!min_heap_.empty()) {
+      return true;
+    }
+    Status status = OkStatus();
+    for (const auto& reader : readers_) {
+      auto status_or_has_more = reader->HasMore();
+      if (status_or_has_more.ok()) {
+        if (status_or_has_more.value()) {
+          return true;
+        }
+      } else {
+        status = status_or_has_more.status();
+      }
+    }
+    if (status.ok()) {
+      // None of the readers has more.
+      return false;
+    }
+    return status;
+  }
+
+  Status Read(std::string* data) override { return Read(data, nullptr); }
+
+  // Copies the pending record with the smallest key into `data`, stores the
+  // position of its source file in `*index` when `index` is not null, and then
+  // queues the next record from that same file. Returns InternalError when no
+  // record is available.
+  Status Read(std::string* data, int* index) override {
+    if (min_heap_.empty()) {
+      // Prime the heap with one record from every reader that has data.
+      for (size_t i = 0; i < readers_.size(); ++i) {
+        RETURN_IF_ERROR(this->ReadHeapDataFromReader(i));
+      }
+    }
+    if (min_heap_.empty()) {
+      // Every reader is exhausted (or none is open). Calling top() on an empty
+      // priority_queue is undefined behavior, so report an error instead.
+      return InternalError("MultiSortedReader: No more records to read.");
+    }
+    HeapData ret_data = min_heap_.top();
+    data->assign(ret_data.data);
+    if (index != nullptr) *index = ret_data.index;
+    min_heap_.pop();
+    return this->ReadHeapDataFromReader(ret_data.index);
+  }
+
+ private:
+  // Reads the next record, if there is one, from readers_[index] and pushes it
+  // onto the heap together with its key and source index.
+  Status ReadHeapDataFromReader(int index) {
+    std::string data;
+    auto status_or_has_more = readers_[index]->HasMore();
+    if (!status_or_has_more.ok()) {
+      return status_or_has_more.status();
+    }
+    if (status_or_has_more.value()) {
+      RETURN_IF_ERROR(readers_[index]->Read(&data));
+      HeapData heap_data;
+      heap_data.key = (*key_)(data);
+      heap_data.data = data;
+      heap_data.index = index;
+      min_heap_.push(heap_data);
+    }
+    return OkStatus();
+  }
+
+  // One pending record: its sort key, its raw bytes, and the position in
+  // readers_ of the reader it came from.
+  struct HeapData {
+    T key;
+    std::string data;
+    int index;
+  };
+
+  // Orders the priority_queue so that the smallest key is on top.
+  struct HeapDataGreater {
+    bool operator()(const HeapData& lhs, const HeapData& rhs) const {
+      return lhs.key > rhs.key;
+    }
+  };
+
+  const std::function<RecordReader*()> get_reader_;
+  std::unique_ptr<std::function<T(absl::string_view)>> default_key_;
+  std::unique_ptr<std::function<T(absl::string_view)>> key_;
+  std::vector<std::unique_ptr<RecordReader>> readers_;
+  std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater>
+      min_heap_;
+};
+
+// RecordWriter that stores each record as a Varint32 size followed by the raw
+// record bytes.
+class RecordWriterImpl : public RecordWriter {
+ public:
+  // Stores `file` in a std::unique_ptr, so this writer owns it from here on.
+  explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {}
+
+  // Opens `filename` on the underlying file in "w" mode.
+  Status Open(absl::string_view filename) final {
+    return out_->Open(filename, "w");
+  }
+
+  Status Close() final { return out_->Close(); }
+
+  // Builds the size-delimited encoding of `raw_data` in memory and writes it
+  // to the underlying file with a single Write call.
+  Status Write(absl::string_view raw_data) final {
+    std::string framed_record;
+    {
+      // Destroying the streams at the end of this scope forces the
+      // serialization, which makes framed_record safe to read afterwards. The
+      // coded stream is declared last so that it is destroyed first.
+      google::protobuf::io::StringOutputStream string_stream(&framed_record);
+      google::protobuf::io::CodedOutputStream coded_stream(&string_stream);
+
+      coded_stream.WriteVarint32(raw_data.size());
+      coded_stream.WriteString(std::string(raw_data));
+    }
+
+    return out_->Write(framed_record, framed_record.size());
+  }
+
+ private:
+  std::unique_ptr<File> out_;
+};
+
+// LineWriter that writes each line to the underlying file followed by "\n".
+class LineWriterImpl : public LineWriter {
+ public:
+  // Stores `file` in a std::unique_ptr, so this writer owns it from here on.
+  explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {}
+
+  // Opens `filename` on the underlying file in "w" mode.
+  Status Open(absl::string_view filename) final {
+    return out_->Open(filename, "w");
+  }
+
+  Status Close() final { return out_->Close(); }
+
+  // Writes `line` and then a newline. Returns the first error encountered.
+  Status Write(absl::string_view line) final {
+    // An absl::string_view is not guaranteed to be null-terminated, so
+    // line.data() must not be passed where it would be converted from a bare
+    // const char* (which measures the buffer as a C string). Copy exactly
+    // line.size() bytes into a std::string first, the same way
+    // RecordWriterImpl hands a std::string to File::Write.
+    const std::string owned_line(line);
+    RETURN_IF_ERROR(out_->Write(owned_line, owned_line.size()));
+    return out_->Write("\n", 1);
+  }
+
+ private:
+  std::unique_ptr<File> out_;
+};
+
+} // namespace
+
+// Returns a new line-oriented RecordReader backed by the File obtained from
+// File::GetFile(). The result is a raw pointer that the caller must own.
+RecordReader* RecordReader::GetLineReader() {
+ return RecordReader::GetLineReader(File::GetFile());
+}
+
+// Returns a new, heap-allocated line-oriented RecordReader that reads from
+// `file`. The reader stores `file` in a std::unique_ptr and therefore takes
+// ownership of it; the caller owns the returned reader.
+RecordReader* RecordReader::GetLineReader(File* file) {
+ return new LineReader(file);
+}
+
+// Returns a new RecordReader for size-delimited records, backed by the File
+// obtained from File::GetFile(). The result is a raw pointer that the caller
+// must own.
+RecordReader* RecordReader::GetRecordReader() {
+ return RecordReader::GetRecordReader(File::GetFile());
+}
+
+// Returns a new, heap-allocated RecordReader for size-delimited records that
+// reads from `file`. The reader stores `file` in a std::unique_ptr and
+// therefore takes ownership of it; the caller owns the returned reader.
+RecordReader* RecordReader::GetRecordReader(File* file) {
+ return new RecordReaderImpl(file);
+}
+
+// Returns a new RecordWriter backed by the File obtained from File::GetFile().
+// The result is a raw pointer that the caller must own.
+RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); }
+
+// Returns a new, heap-allocated RecordWriter that writes to `file`. The writer
+// stores `file` in a std::unique_ptr and therefore takes ownership of it; the
+// caller owns the returned writer.
+RecordWriter* RecordWriter::Get(File* file) {
+ return new RecordWriterImpl(file);
+}
+
+// Returns a new LineWriter backed by the File obtained from File::GetFile().
+// The result is a raw pointer that the caller must own.
+LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); }
+
+// Returns a new, heap-allocated LineWriter that writes to `file`. The writer
+// stores `file` in a std::unique_ptr and therefore takes ownership of it; the
+// caller owns the returned writer.
+LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); }
+
+// Returns a MultiSortedReader whose per-file readers are created with
+// RecordReader::GetRecordReader(), i.e. readers for size-delimited records.
+template <typename T>
+MultiSortedReader<T>* MultiSortedReader<T>::Get() {
+ return MultiSortedReader<T>::Get(
+ []() { return RecordReader::GetRecordReader(); });
+}
+
+// std::string specialization. The default key function copies the whole record
+// into a std::string, so the single-argument Open() compares records by their
+// full contents.
+template <>
+MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get(
+ const std::function<RecordReader*()>& get_reader) {
+ return new MultiSortedReaderImpl<std::string>(
+ get_reader,
+ absl::make_unique<std::function<std::string(absl::string_view)>>(
+ [](absl::string_view s) { return std::string(s); }));
+}
+
+// int64_t specialization. The default key function ignores the record and
+// always returns 0, so with the single-argument Open() every record has the
+// same key; pass an explicit key to Open() to get a meaningful order.
+template <>
+MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get(
+ const std::function<RecordReader*()>& get_reader) {
+ return new MultiSortedReaderImpl<int64_t>(
+ get_reader, absl::make_unique<std::function<int64_t(absl::string_view)>>(
+ [](absl::string_view s) { return 0; }));
+}
+
+template class MultiSortedReader<int64_t>;
+template class MultiSortedReader<std::string>;
+
+namespace {
+
+// Builds a shard file name by appending the decimal shard index `idx` directly
+// to `prefix`, with no separator in between.
+std::string GetFilename(absl::string_view prefix, int32_t idx) {
+ return absl::StrCat(prefix, idx);
+}
+
+// ShardingWriter implementation that buffers records in memory and, once the
+// buffered byte count passes max_bytes_, sorts them by key and writes them out
+// as one shard file named <prefix><index>. All public methods take mutex_.
+template <typename T>
+class ShardingWriterImpl : public ShardingWriter<T> {
+ public:
+ // Error returned once an earlier open or write of a shard file has failed.
+ static Status AlreadyUnhealthyError() {
+ return InternalError("ShardingWriter: Already unhealthy.");
+ }
+
+ // `get_key` extracts the sort key from a raw record. `max_bytes` is the
+ // buffered-byte threshold that triggers a flush. `record_writer` is used to
+ // write every shard file; by default it comes from RecordWriter::Get().
+ explicit ShardingWriterImpl(
+ const std::function<T(absl::string_view)>& get_key,
+ int32_t max_bytes = 209715200, /* 200MB */
+ std::unique_ptr<RecordWriter> record_writer =
+ absl::WrapUnique(RecordWriter::Get()))
+ : get_key_(get_key),
+ record_writer_(std::move(record_writer)),
+ max_bytes_(max_bytes),
+ cache_(),
+ bytes_written_(0),
+ current_file_idx_(0),
+ shard_files_(),
+ healthy_(true),
+ open_(false) {}
+
+ // Marks the writer as open and derives the name of the next shard file from
+ // `shard_prefix` and the current shard index. Does not reset the cache or
+ // the list of shard files already written.
+ void SetShardPrefix(absl::string_view shard_prefix) override {
+ absl::MutexLock lock(&mutex_);
+ open_ = true;
+ fnames_prefix_ = std::string(shard_prefix);
+ current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
+ }
+
+ // Flushes the remaining cache via TryClose() and returns its result (the list
+ // of shard files on success). The writer is reset to its initial, unopened
+ // state whether or not the flush succeeded.
+ StatusOr<std::vector<std::string>> Close() override {
+ absl::MutexLock lock(&mutex_);
+
+ auto retval = TryClose();
+
+ // Guarantee that the state is reset, even if TryClose fails.
+ fnames_prefix_ = "";
+ current_fname_ = "";
+ healthy_ = true;
+ cache_.clear();
+ bytes_written_ = 0;
+ shard_files_.clear();
+ current_file_idx_ = 0;
+ open_ = false;
+
+ return retval;
+ }
+
+ // Adds the supplied record to the in-memory cache. If the bytes cached so far
+ // already exceed max_bytes_, the cache is first sorted and flushed to a shard
+ // file. Returns InternalError if SetShardPrefix has not been called, if the
+ // writer is unhealthy, or if the flush fails; returns OkStatus otherwise.
+ Status Write(absl::string_view raw_record) override {
+ absl::MutexLock lock(&mutex_);
+ if (!open_) {
+ return InternalError("Must call SetShardPrefix before calling Write.");
+ }
+ if (!healthy_) {
+ return AlreadyUnhealthyError();
+ }
+ // The threshold is checked before the new record is counted, so a shard can
+ // end up holding more than max_bytes_ bytes.
+ if (bytes_written_ > max_bytes_) {
+ RETURN_IF_ERROR(WriteCacheToFile());
+ }
+ bytes_written_ += raw_record.size();
+ cache_.push_back(std::string(raw_record));
+ return OkStatus();
+ }
+
+ private:
+ // Sorts the cache by key and writes it to current_fname_, then advances to
+ // the next shard file name. Does nothing when the cache is empty. If the
+ // shard file cannot be opened, the writer is marked unhealthy and the method
+ // returns without touching the cache. If a record cannot be written, the
+ // writer is marked unhealthy, but the file is still closed, recorded in
+ // shard_files_, and the cache is cleared before the error is returned.
+ Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ if (!healthy_) return AlreadyUnhealthyError();
+ if (cache_.empty()) return OkStatus();
+ cache_.sort([this](absl::string_view r1, absl::string_view r2) {
+ return get_key_(r1) < get_key_(r2);
+ });
+ if (!record_writer_->Open(current_fname_).ok()) {
+ healthy_ = false;
+ return InternalError(
+ absl::StrCat("Cannot open ", current_fname_, " for writing."));
+ }
+ Status status = absl::OkStatus();
+ for (absl::string_view r : cache_) {
+ if (!record_writer_->Write(r).ok()) {
+ healthy_ = false;
+ status = InternalError(
+ absl::StrCat("Cannot write record ", r, " to ", current_fname_));
+
+ break;
+ }
+ }
+ if (!record_writer_->Close().ok()) {
+ if (status.ok()) {
+ status =
+ InternalError(absl::StrCat("Cannot close ", current_fname_, "."));
+ } else {
+ // Preserve the old status message.
+ LOG(WARNING) << "Cannot close " << current_fname_;
+ }
+ }
+
+ shard_files_.push_back(current_fname_);
+ cache_.clear();
+ bytes_written_ = 0;
+ ++current_file_idx_;
+ current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
+ return status;
+ }
+
+ // Flushes the remaining cache and returns the names of all shard files
+ // written so far. Returns InternalError if SetShardPrefix has not been
+ // called, or the flush error if the flush fails.
+ StatusOr<std::vector<std::string>> TryClose()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ if (!open_) {
+ return InternalError("Must call SetShardPrefix before calling Close.");
+ }
+ RETURN_IF_ERROR(WriteCacheToFile());
+
+ return {shard_files_};
+ }
+
+ absl::Mutex mutex_;
+ // Extracts the sort key from a raw record.
+ std::function<T(absl::string_view)> get_key_;
+ std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_);
+ std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_);
+ // Flush threshold, in bytes of cached record data.
+ const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_);
+ // Records waiting to be sorted and written to the next shard file.
+ std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_);
+ // Total size of the records currently held in cache_.
+ int32_t bytes_written_ ABSL_GUARDED_BY(mutex_);
+ int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_);
+ std::string current_fname_ ABSL_GUARDED_BY(mutex_);
+ std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_);
+ // False once opening or writing a shard file has failed; reset by Close().
+ bool healthy_ ABSL_GUARDED_BY(mutex_);
+ // True between SetShardPrefix() and Close().
+ bool open_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace
+
+// Returns a ShardingWriterImpl that sorts by `get_key`, flushes after
+// `max_bytes` cached bytes, and writes shards with the implementation's
+// default RecordWriter.
+template <typename T>
+std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) {
+ return absl::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes);
+}
+
+// Test only. Same as the overload above, but the shards are written through
+// the supplied `record_writer`, whose ownership moves into the returned writer.
+template <typename T>
+std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
+ std::unique_ptr<RecordWriter> record_writer) {
+ return absl::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes,
+ std::move(record_writer));
+}
+
+template class ShardingWriter<int64_t>;
+template class ShardingWriter<std::string>;
+
+// Takes ownership of the reader used to read the shards and of the writer used
+// to produce the merged output.
+template <typename T>
+ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,
+ std::unique_ptr<RecordWriter> writer)
+ : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {}
+
+// Merges the sorted `shard_files` into `output_file`, ordering records by
+// `get_key`. When `shard_files` is empty, an empty output file is created.
+//
+// Returns the first error reported by the reader or the writer. An error from
+// MultiSortedReader::HasMore() is propagated instead of being unwrapped
+// unchecked.
+template <typename T>
+Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key,
+                             const std::vector<std::string>& shard_files,
+                             absl::string_view output_file) {
+  if (shard_files.empty()) {
+    // Nothing to merge: create an empty output file and stop, rather than
+    // falling through and opening the same output file a second time.
+    RETURN_IF_ERROR(writer_->Open(output_file));
+    return writer_->Close();
+  }
+
+  // Multi-sorted-read all shards, and write the results to the supplied file.
+  RETURN_IF_ERROR(multi_reader_->Open(shard_files, get_key));
+
+  RETURN_IF_ERROR(writer_->Open(output_file));
+
+  std::string record;
+  while (true) {
+    // Check the status before looking at the value: value() on a failed
+    // StatusOr must not be called.
+    StatusOr<bool> has_more = multi_reader_->HasMore();
+    if (!has_more.ok()) {
+      return has_more.status();
+    }
+    if (!has_more.value()) {
+      break;
+    }
+    RETURN_IF_ERROR(multi_reader_->Read(&record));
+    RETURN_IF_ERROR(writer_->Write(record));
+  }
+  RETURN_IF_ERROR(writer_->Close());
+
+  RETURN_IF_ERROR(multi_reader_->Close());
+
+  return OkStatus();
+}
+
+// Deletes every file named in `shard_files`, in order. Stops at, and returns,
+// the first DeleteFile error; returns OkStatus when all deletions succeed.
+template <typename T>
+Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) {
+  for (const std::string& shard_path : shard_files) {
+    RETURN_IF_ERROR(DeleteFile(shard_path));
+  }
+  return OkStatus();
+}
+
+template class ShardMerger<int64_t>;
+template class ShardMerger<std::string>;
+
+} // namespace private_join_and_compute
diff --git a/util/recordio.h b/util/recordio.h
new file mode 100644
index 0000000..32d11b4
--- /dev/null
+++ b/util/recordio.h
@@ -0,0 +1,273 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Defines file operations.
+// RecordWriter generates output records that are binary data preceded with a
+// Varint that explains the size of the records. The records provided to
+// RecordWriter can be arbitrary binary data, but usually they will be
+// serialized protobufs.
+//
+// copybara:strip_begin(internal comment)
+// Note that this library is not in any way compatible with the Google3 RecordIo
+// class, but rather uses the suggested proto-serialization described in
+// https://developers.google.com/protocol-buffers/docs/techniques#streaming
+// This encoding is also compatible with Java parseDelimitedFrom and
+// writeDelimitedTo.
+// copybara:strip_end
+//
+// RecordReader reads files written in the above format, and is also compatible
+// with files written using the Java version of parseDelimitedFrom and
+// writeDelimitedTo.
+//
+// LineWriter writes single lines to the output file. LineReader reads single
+// lines from the input file.
+//
+// Note that all classes except ShardingWriter are not thread-safe: concurrent
+// accesses must be protected by mutexes.
+
+#ifndef INTERNAL_UTIL_RECORDIO_H_
+#define INTERNAL_UTIL_RECORDIO_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "util/file.h"
+#include "util/status.inc"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+
+namespace private_join_and_compute {
+
+// Interface for reading records, one at a time, from a single file.
+class RecordReader {
+ public:
+ virtual ~RecordReader() = default;
+
+ // RecordReader is neither copyable nor movable.
+ RecordReader(const RecordReader&) = delete;
+ RecordReader& operator=(const RecordReader&) = delete;
+
+ // Opens the given file for reading.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes the file opened by a previous call to Open.
+ virtual Status Close() = 0;
+
+ // Returns true if there are more records in the file to be read, or an error
+ // status if that could not be determined.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Reads a record from the file (line or binary record) into `record`.
+ virtual Status Read(std::string* record) = 0;
+
+ // Returns a RecordReader for reading files line by line.
+ // Caller takes the ownership.
+ static RecordReader* GetLineReader();
+
+ // Returns a RecordReader for reading files in a record format compatible with
+ // RecordWriter below.
+ // Caller takes the ownership.
+ static RecordReader* GetRecordReader();
+
+ // Test only. Same as the factories above, but reading through the supplied
+ // File instead of a default one.
+ static RecordReader* GetLineReader(File* file);
+ static RecordReader* GetRecordReader(File* file);
+
+ protected:
+ RecordReader() = default;
+};
+
+// Reads records one at a time in ascending order from multiple files, assuming
+// each file stores records in ascending order. This class does the merge step
+// for the external sorting. Templates T supported are string and int64.
+template <typename T>
+class MultiSortedReader {
+ public:
+ virtual ~MultiSortedReader() = default;
+
+ // MultiSortedReader is neither copyable nor movable.
+ MultiSortedReader(const MultiSortedReader&) = delete;
+ MultiSortedReader& operator=(const MultiSortedReader&) = delete;
+
+ // Opens the given files, e.g. files written with RecordWriter. Records in
+ // each file are assumed to be sorted beforehand.
+ virtual Status Open(const std::vector<std::string>& filenames) = 0;
+
+ // Same as Open above but also accepts a key function that is used to convert
+ // a string record into a value of type T, used when comparing the records.
+ // Records will be read from the file heads in ascending order of "key".
+ virtual Status Open(const std::vector<std::string>& filenames,
+ const std::function<T(absl::string_view)>& key) = 0;
+
+ // Closes the file streams.
+ virtual Status Close() = 0;
+
+ // Returns true if there are more records to be read from any of the files.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Reads a record data into <code>data</code> in ascending order.
+ // Erases the <code>data</code> before writing to it.
+ virtual Status Read(std::string* data) = 0;
+
+ // Same as Read(string* data) but this also puts the index of the file
+ // where the data has been read from if index is not nullptr.
+ // Erases the <code>data</code> before writing to it.
+ virtual Status Read(std::string* data, int* index) = 0;
+
+ // Returns a MultiSortedReader.
+ // Caller takes the ownership.
+ static MultiSortedReader<T>* Get();
+
+ // Test only. `get_reader` is the factory used to create the reader for each
+ // file.
+ static MultiSortedReader* Get(
+ const std::function<RecordReader*()>& get_reader);
+
+ protected:
+ MultiSortedReader() = default;
+};
+
+// Interface for writing size-delimited binary records to a single file.
+class RecordWriter {
+ public:
+ virtual ~RecordWriter() = default;
+
+ // RecordWriter is neither copyable nor movable.
+ RecordWriter(const RecordWriter&) = delete;
+ RecordWriter& operator=(const RecordWriter&) = delete;
+
+ // Opens the given file for writing records.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes the file stream and returns OkStatus if successful.
+ virtual Status Close() = 0;
+
+ // Writes <code>raw_data</code> into the file as-is, with a delimiter
+ // specifying the data size.
+ virtual Status Write(absl::string_view raw_data) = 0;
+
+ // Returns a RecordWriter.
+ // Caller takes the ownership.
+ static RecordWriter* Get();
+
+ // Test only. Same as Get() above, but writing through the supplied File.
+ static RecordWriter* Get(File* file);
+
+ protected:
+ RecordWriter() = default;
+};
+
+// Interface for writing text lines to a single file.
+class LineWriter {
+ public:
+ virtual ~LineWriter() = default;
+
+ // LineWriter is neither copyable nor movable.
+ LineWriter(const LineWriter&) = delete;
+ LineWriter& operator=(const LineWriter&) = delete;
+
+ // Opens the given file for writing lines.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes the file stream and returns OkStatus if successful.
+ virtual Status Close() = 0;
+
+ // Writes <code>line</code> into the file, with a trailing newline.
+ // Returns OkStatus if the write operation was successful.
+ virtual Status Write(absl::string_view line) = 0;
+
+ // Returns a LineWriter.
+ // Caller takes the ownership.
+ static LineWriter* Get();
+
+ // Test only. Same as Get() above, but writing through the supplied File.
+ static LineWriter* Get(File* file);
+
+ protected:
+ LineWriter() = default;
+};
+
+// Writes Records to shard files, with each shard file internally sorted based
+// on the supplied get_key method.
+//
+// This class is thread-safe.
+template <typename T>
+class ShardingWriter {
+ public:
+ virtual ~ShardingWriter() = default;
+
+ // ShardingWriter is neither copyable nor copy-assignable.
+ ShardingWriter(const ShardingWriter&) = delete;
+ ShardingWriter& operator=(const ShardingWriter&) = delete;
+
+ // Shards will be created with the supplied prefix. Must be called before
+ // Write.
+ virtual void SetShardPrefix(absl::string_view shard_prefix) = 0;
+
+ // Clears the remaining cache, and returns the list of all shard files that
+ // were written since the last call to SetShardPrefix. Caller is responsible
+ // for merging and deleting shards.
+ //
+ // Returns InternalError if clearing the remaining cache fails.
+ virtual StatusOr<std::vector<std::string>> Close() = 0;
+
+ // Writes the supplied raw_data into the file.
+ // Implementations need not actually write the record on each call. Rather,
+ // they may cache records until max_bytes bytes of records have been cached,
+ // at which point they may sort the cache and write it to a shard file.
+ //
+ // Implementations must return InternalError if writing the cache fails, or
+ // if the shard prefix has not been set.
+ virtual Status Write(absl::string_view raw_data) = 0;
+
+ // Returns a ShardingWriter that uses the supplied key to compare records.
+ // @param max_bytes: denotes the maximum size of each shard to write.
+ static std::unique_ptr<ShardingWriter> Get(
+ const std::function<T(absl::string_view)>& get_key,
+ int32_t max_bytes = 209715200 /* 200MB */);
+
+ // Test only. Same as Get above, but shards are written through the supplied
+ // record_writer.
+ static std::unique_ptr<ShardingWriter> Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
+ std::unique_ptr<RecordWriter> record_writer);
+
+ protected:
+ ShardingWriter() = default;
+};
+
+// Utility class to allow merging of sorted shards, and deleting of shards.
+template <typename T>
+class ShardMerger {
+ public:
+ // Takes ownership of the reader used to read the shards and of the writer
+ // used for the merged output. By default these come from
+ // MultiSortedReader<T>::Get() and RecordWriter::Get().
+ explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader =
+ absl::WrapUnique(MultiSortedReader<T>::Get()),
+ std::unique_ptr<RecordWriter> writer =
+ absl::WrapUnique(RecordWriter::Get()));
+
+ // Merges the supplied shards into a single output file, using the supplied
+ // key.
+ Status Merge(const std::function<T(absl::string_view)>& get_key,
+ const std::vector<std::string>& shard_files,
+ absl::string_view output_file);
+
+ // Deletes the supplied shard files.
+ Status Delete(std::vector<std::string> shard_files);
+
+ private:
+ std::unique_ptr<MultiSortedReader<T>> multi_reader_;
+ std::unique_ptr<RecordWriter> writer_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // INTERNAL_UTIL_RECORDIO_H_
diff --git a/util/recordio_test.cc b/util/recordio_test.cc
new file mode 100644
index 0000000..5ba4fdf
--- /dev/null
+++ b/util/recordio_test.cc
@@ -0,0 +1,508 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/recordio.h"
+
+#include <fstream>
+
+#include "crypto/context.h"
+#include "util/file_test.pb.h"
+#include "util/proto_util.h"
+#include "util/status.inc"
+#include "util/status_testing.inc"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/random/random.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace private_join_and_compute {
+namespace {
+
+using ::private_join_and_compute::testing::TestProto;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using testing::IsOkAndHolds;
+using testing::StatusIs;
+using ::testing::TempDir;
+
+// Builds a TestProto whose `record` field is `data` and whose `dummy` field is
+// `dummy`, and returns it serialized with ProtoUtils::ToString.
+std::string GetTestPBWithDummyAsStr(const std::string& data,
+                                    const std::string& dummy) {
+  TestProto proto;
+  proto.set_dummy(dummy);
+  proto.set_record(data);
+  return ProtoUtils::ToString(proto);
+}
+
+// Reads every record in `filename` with a RecordReader, parses each one as a
+// TestProto, and expects the sequence of `record` fields to equal
+// `expected_ids`.
+//
+// Failures to open the file, to query HasMore, or to read a record are fatal
+// for this helper: it returns immediately rather than calling value() on a
+// failed StatusOr or looping on a reader that keeps failing. A fatal failure
+// here only returns from this helper, not from the calling test.
+void ExpectFileContainsRecords(absl::string_view filename,
+                               const std::vector<std::string>& expected_ids) {
+  std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+  ASSERT_OK(reader->Open(filename));
+
+  std::vector<std::string> ids_read;
+  while (true) {
+    const auto has_more = reader->HasMore();
+    ASSERT_TRUE(has_more.ok()) << has_more.status();
+    if (!has_more.value()) break;
+
+    std::string raw_record;
+    ASSERT_OK(reader->Read(&raw_record));
+    ids_read.push_back(ProtoUtils::FromString<TestProto>(raw_record).record());
+  }
+  // Close explicitly, as the tests in this file do for their own readers.
+  EXPECT_OK(reader->Close());
+  EXPECT_THAT(ids_read, ElementsAreArray(expected_ids));
+}
+
+// Returns a TestProto whose `record` field is set to `id`; all other fields
+// are left at their defaults.
+TestProto GetRecord(const std::string& id) {
+  TestProto proto;
+  proto.set_record(id);
+  return proto;
+}
+
+// Expects `status` to carry StatusCode::kInternal and a message that contains
+// `substring`.
+void ExpectInternalErrorWithSubstring(const Status& status,
+                                      const std::string& substring) {
+  const auto expected_code = private_join_and_compute::StatusCode::kInternal;
+  EXPECT_THAT(status, StatusIs(expected_code, HasSubstr(substring)));
+}
+
+// Writes a single record and checks that a record reader opened on the same
+// file returns it unchanged.
+TEST(FileTest, WriteRecordThenReadTest) {
+  const std::string path = TempDir() + "test_file.txt";
+
+  std::unique_ptr<RecordWriter> record_writer(RecordWriter::Get());
+  EXPECT_OK(record_writer->Open(path));
+  EXPECT_OK(record_writer->Write("data"));
+  EXPECT_OK(record_writer->Close());
+
+  std::unique_ptr<RecordReader> record_reader(RecordReader::GetRecordReader());
+  EXPECT_OK(record_reader->Open(path));
+  std::string contents;
+  EXPECT_OK(record_reader->Read(&contents));
+  EXPECT_EQ("data", contents);
+  EXPECT_OK(record_reader->Close());
+}
+
+// A reader that is already open must reject a second Open call.
+TEST(FileTest, CannotOpenIfAlreadyOpened) {
+  const std::string path = TempDir() + "test_file.txt";
+
+  std::unique_ptr<RecordWriter> record_writer(RecordWriter::Get());
+  EXPECT_OK(record_writer->Open(path));
+  EXPECT_OK(record_writer->Write("data"));
+  EXPECT_OK(record_writer->Close());
+
+  std::unique_ptr<RecordReader> record_reader(RecordReader::GetRecordReader());
+  EXPECT_OK(record_reader->Open(path));
+  const bool reopened = record_reader->Open(path).ok();
+  EXPECT_FALSE(reopened);
+}
+
+// After Close, the same reader object can be opened again on the same file.
+TEST(FileTest, OpensIfClosed) {
+  const std::string path = TempDir() + "test_file.txt";
+
+  std::unique_ptr<RecordWriter> record_writer(RecordWriter::Get());
+  EXPECT_OK(record_writer->Open(path));
+  EXPECT_OK(record_writer->Write("data"));
+  EXPECT_OK(record_writer->Close());
+
+  std::unique_ptr<RecordReader> record_reader(RecordReader::GetRecordReader());
+  EXPECT_OK(record_reader->Open(path));
+  EXPECT_OK(record_reader->Close());
+  EXPECT_OK(record_reader->Open(path));
+}
+
+// Writes three records (plain text, a string with an embedded NUL byte, and
+// the byte encoding of a BigNum) and checks that they are read back unchanged,
+// in the order written, with HasMore turning false after the last one.
+TEST(FileTest, WriteMultipleRecordsThenReadTest) {
+ Context ctx;
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("the first record."));
+ // Construct the string with an explicit length of 10 so that the embedded
+ // '\0' and the bytes after it are kept.
+ char written2_char[] = "raw\0record";
+ std::string written2(written2_char, 10);
+ EXPECT_OK(rw->Write(written2));
+ std::string num_bytes = ctx.CreateBigNum(1111111111).ToBytes();
+ EXPECT_OK(rw->Write(num_bytes));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ std::string read;
+ EXPECT_TRUE(rr->HasMore().value());
+ EXPECT_OK(rr->Read(&read));
+ EXPECT_EQ("the first record.", read);
+ EXPECT_TRUE(rr->HasMore().value());
+ std::string raw_read;
+ EXPECT_OK(rr->Read(&raw_read));
+ EXPECT_EQ(written2, raw_read);
+ // The record must not have been truncated at the embedded NUL.
+ EXPECT_NE("raw", raw_read);
+ EXPECT_EQ(10, raw_read.size());
+ EXPECT_TRUE(rr->HasMore().value());
+ EXPECT_OK(rr->Read(&read));
+ EXPECT_EQ(num_bytes, read);
+ EXPECT_FALSE(rr->HasMore().value());
+ EXPECT_OK(rr->Close());
+}
+
+// Spreads six records over three files (two per file, with the smallest
+// records in the last file) and checks that MultiSortedReader returns them in
+// ascending order. Also checks that Open fails while the reader is still open
+// and succeeds again after Close.
+TEST(FileTest, MultiSortReaderReadsInSortedOrder) {
+  const std::vector<std::string> filenames = {TempDir() + "test_file0",
+                                              TempDir() + "test_file1",
+                                              TempDir() + "test_file2"};
+  // Every record is three bytes long: '1', one distinguishing byte, and the
+  // terminating NUL of the literal.
+  const std::vector<std::string> records = {
+      std::string("1\00", 3), std::string("1\01", 3), std::string("1\02", 3),
+      std::string("1\03", 3), std::string("1\04", 3), std::string("1\05", 3)};
+
+  std::unique_ptr<RecordWriter> record_writer(RecordWriter::Get());
+  EXPECT_OK(record_writer->Open(filenames[0]));
+  EXPECT_OK(record_writer->Write(records[4]));
+  EXPECT_OK(record_writer->Write(records[5]));
+  EXPECT_OK(record_writer->Close());
+
+  EXPECT_OK(record_writer->Open(filenames[1]));
+  EXPECT_OK(record_writer->Write(records[2]));
+  EXPECT_OK(record_writer->Write(records[3]));
+  EXPECT_OK(record_writer->Close());
+
+  EXPECT_OK(record_writer->Open(filenames[2]));
+  EXPECT_OK(record_writer->Write(records[0]));
+  EXPECT_OK(record_writer->Write(records[1]));
+  EXPECT_OK(record_writer->Close());
+
+  std::unique_ptr<MultiSortedReader<std::string>> sorted_reader(
+      MultiSortedReader<std::string>::Get());
+  EXPECT_OK(sorted_reader->Open(filenames));
+
+  // The records come back in the order of `records`, regardless of which file
+  // holds them.
+  std::string data;
+  for (const std::string& expected : records) {
+    EXPECT_TRUE(sorted_reader->HasMore().value());
+    EXPECT_OK(sorted_reader->Read(&data));
+    EXPECT_EQ(expected, data);
+  }
+  EXPECT_FALSE(sorted_reader->HasMore().value());
+
+  // Re-opening is rejected until the reader has been closed.
+  const bool reopened_while_open = sorted_reader->Open(filenames).ok();
+  EXPECT_FALSE(reopened_while_open);
+  EXPECT_OK(sorted_reader->Close());
+  EXPECT_OK(sorted_reader->Open(filenames));
+  EXPECT_OK(sorted_reader->Close());
+}
+
+// Opens the MultiSortedReader with a key function that extracts the `record`
+// field of a serialized TestProto, and checks that records from two files come
+// back ordered by that field ("1".."4"). The `dummy` values differ per record
+// and are part of the expected serialized output.
+TEST(FileTest, MultiSortReaderSortsBasedOnProtoKeyField) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ // File 0 holds keys "1" and "3"; file 1 holds keys "2" and "4".
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("1", "tiny")));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("3", "ti")));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("2", "tin")));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("4", "t")));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames, [](absl::string_view raw_data) {
+ return ProtoUtils::FromString<TestProto>(raw_data).record();
+ }));
+ std::string data;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("1", "tiny"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("2", "tin"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("3", "ti"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("4", "t"), data);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+// Checks the two-argument Read overload: besides the record, it is expected to
+// report the position in `filenames` of the file the record came from.
+TEST(FileTest, MultiSortReaderReadsIndicesAsWell) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ // File 0 holds "1" and "3"; file 1 holds "2".
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write("1"));
+ EXPECT_OK(rw->Write("3"));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames));
+ std::string data;
+ int index;
+ // Sorted order is "1" (file 0), "2" (file 1), "3" (file 0).
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(1, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+// Writes the record "2" to both files (file 0 additionally starts with "1")
+// and checks which file index is reported for each of the three reads.
+//
+// NOTE(review): the test name says duplicates are read in order of the file
+// index, which would suggest indices 0, 0, 1 here. The expectations below are
+// 0, 1, 0, i.e. the "2" from file 1 is expected before the "2" from file 0.
+// Confirm against the MultiSortedReader implementation which one is intended.
+TEST(FileTest, MultiSortReaderReadsDuplicateRecordsInOrderOfTheFileIndex) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ // File 0 holds "1" and "2"; file 1 holds "2".
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write("1"));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames));
+ std::string data;
+ int index;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(1, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+// Reads a text file through the line reader. The file ends with an empty line
+// ("\n\n"), so three lines are expected: "Line1", "Line2" and "".
+TEST(FileTest, LineReaderTest) {
+ // The fixture is written with a plain ofstream rather than a project writer.
+ std::ofstream ofs(TempDir() + "test_file.txt");
+ ofs << "Line1\nLine2\n\n";
+ ofs.close();
+ auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader());
+ EXPECT_OK(lr->Open(TempDir() + "test_file.txt"));
+ std::string line;
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line1", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line2", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("", line);
+ EXPECT_FALSE(lr->HasMore().value());
+ EXPECT_OK(lr->Close());
+}
+
+// Same as LineReaderTest, but the file does not end with a newline: the last
+// line ("Line2") must still be returned, and no extra empty line is expected.
+TEST(FileTest, LineReaderTestWithoutNewline) {
+ std::ofstream ofs(TempDir() + "test_file.txt");
+ ofs << "Line1\nLine2";
+ ofs.close();
+ auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader());
+ EXPECT_OK(lr->Open(TempDir() + "test_file.txt"));
+ std::string line;
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line1", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line2", line);
+ EXPECT_FALSE(lr->HasMore().value());
+ EXPECT_OK(lr->Close());
+}
+
+// Writes one line with LineWriter and checks that the line reader returns the
+// same text.
+TEST(FileTest, LineWriterTest) {
+  const std::string path = TempDir() + "test_file.txt";
+
+  std::unique_ptr<LineWriter> line_writer(LineWriter::Get());
+  EXPECT_OK(line_writer->Open(path));
+  EXPECT_OK(line_writer->Write("data"));
+  EXPECT_OK(line_writer->Close());
+
+  std::unique_ptr<RecordReader> line_reader(RecordReader::GetLineReader());
+  EXPECT_OK(line_reader->Open(path));
+  std::string contents;
+  EXPECT_OK(line_reader->Read(&contents));
+  EXPECT_EQ("data", contents);
+  EXPECT_OK(line_reader->Close());
+}
+
+// With max_bytes set to 1, each of the three records is expected to end up in
+// its own shard file, named <prefix>0, <prefix>1, <prefix>2 in write order.
+TEST(ShardingWriterTest, WritesInShards) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/1);
+ writer->SetShardPrefix(TempDir() + "test_file");
+
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+ // Close is expected to return the names of the shard files.
+ EXPECT_THAT(writer->Close(),
+ IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ TempDir() + "test_file2"})));
+
+ ExpectFileContainsRecords(TempDir() + "test_file0", {"22"});
+ ExpectFileContainsRecords(TempDir() + "test_file1", {"33"});
+ ExpectFileContainsRecords(TempDir() + "test_file2", {"11"});
+}
+
+// With max_bytes set to 100, all three records are expected to land in a
+// single shard, sorted by the `record` key even though they were written out
+// of order.
+TEST(ShardingWriterTest, WritesInSortedShards) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/100);
+ writer->SetShardPrefix(TempDir() + "test_file");
+
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+ EXPECT_THAT(writer->Close(),
+ IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0"})));
+
+ ExpectFileContainsRecords(TempDir() + "test_file0", {"11", "22", "33"});
+}
+
+// Closing a writer that never received a record is expected to succeed and to
+// report an empty list of shard files.
+TEST(ShardingWriterTest, CreatesNoShardsWhenNoRecordsWritten) {
+  const auto record_key = [](absl::string_view raw_record) {
+    return ProtoUtils::FromString<TestProto>(raw_record).record();
+  };
+  auto sharding_writer =
+      ShardingWriter<std::string>::Get(record_key, /*max_bytes=*/1);
+  sharding_writer->SetShardPrefix(TempDir() + "test_file");
+  EXPECT_THAT(sharding_writer->Close(), IsOkAndHolds(IsEmpty()));
+}
+
+// Write must fail with an internal error when SetShardPrefix was never called.
+TEST(ShardingWriterTest, FailsIfWriteBeforeSettingOutputFilenames) {
+  const auto record_key = [](absl::string_view raw_record) {
+    return ProtoUtils::FromString<TestProto>(raw_record).record();
+  };
+  auto sharding_writer =
+      ShardingWriter<std::string>::Get(record_key, /*max_bytes=*/100);
+
+  const Status write_status =
+      sharding_writer->Write(ProtoUtils::ToString(GetRecord("22")));
+  ExpectInternalErrorWithSubstring(
+      write_status, "Must call SetShardPrefix before calling Write.");
+}
+
+// Close must fail with an internal error when SetShardPrefix was never called.
+TEST(ShardingWriterTest, FailsIfCloseBeforeSettingOutputFilenames) {
+  const auto record_key = [](absl::string_view raw_record) {
+    return ProtoUtils::FromString<TestProto>(raw_record).record();
+  };
+  auto sharding_writer =
+      ShardingWriter<std::string>::Get(record_key, /*max_bytes=*/100);
+
+  const Status close_status = sharding_writer->Close().status();
+  ExpectInternalErrorWithSubstring(
+      close_status, "Must call SetShardPrefix before calling Close.");
+}
+
+// Merges two sorted shard files and checks that the output file contains all
+// six records in ascending key order.
+TEST(ShardingMergerTest, MergesMultipleFilesCorrectly) {
+  const std::vector<std::string> shard_files = {TempDir() + "test_file0",
+                                                TempDir() + "test_file1"};
+  const std::string output_file = TempDir() + "output";
+
+  // Setup failures are fatal, as in the other ShardingMergerTest cases:
+  // continuing with missing or partial shards would only produce misleading
+  // failures further down.
+  std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+  ASSERT_OK(writer->Open(shard_files[0]));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44"))));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66"))));
+  ASSERT_OK(writer->Close());
+  ASSERT_OK(writer->Open(shard_files[1]));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("77"))));
+  ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("99"))));
+  ASSERT_OK(writer->Close());
+
+  ShardMerger<std::string> merger;
+  EXPECT_OK(merger.Merge(
+      [](absl::string_view raw_record) {
+        return ProtoUtils::FromString<TestProto>(raw_record).record();
+      },
+      shard_files, output_file));
+
+  std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+  EXPECT_OK(reader->Open(output_file));
+  const std::vector<std::string> expected_ids = {"11", "22", "44",
+                                                 "66", "77", "99"};
+  std::string record;
+  for (const std::string& expected_id : expected_ids) {
+    EXPECT_OK(reader->Read(&record));
+    EXPECT_EQ(expected_id, ProtoUtils::FromString<TestProto>(record).record());
+  }
+  EXPECT_FALSE(reader->HasMore().value());
+  EXPECT_OK(reader->Close());
+}
+
+// Merging a single sorted shard is expected to copy its three records to the
+// output file in the same order.
+TEST(ShardingMergerTest, MergesSingleFileCorrectly) {
+ // Setup uses ASSERT_OK: without the shard file the rest of the test is
+ // meaningless.
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ ASSERT_OK(writer->Open(TempDir() + "test_file0"));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44"))));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66"))));
+ ASSERT_OK(writer->Close());
+
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Merge(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ {TempDir() + "test_file0"}, TempDir() + "output"));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_OK(reader->Open(TempDir() + "output"));
+ std::string record;
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_FALSE(reader->HasMore().value());
+ EXPECT_OK(reader->Close());
+}
+
+// Merging an empty list of shards is expected to succeed and to leave behind
+// an output file that can be opened and contains no records.
+TEST(ShardingMergerTest, CreatesEmptyFileIfNoShardsProvided) {
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Merge(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ {} /* no shard files */, TempDir() + "output"));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_OK(reader->Open(TempDir() + "output"));
+ EXPECT_FALSE(reader->HasMore().value());
+ EXPECT_OK(reader->Close());
+}
+
+// Creates three empty files, deletes them through ShardMerger::Delete, and
+// checks that none of them can be opened afterwards.
+TEST(ShardingMergerTest, DeletesFiles) {
+  const std::vector<std::string> shard_files = {TempDir() + "test_file0",
+                                                TempDir() + "test_file1",
+                                                TempDir() + "test_file2"};
+
+  // Opening and closing a writer is enough to create each (empty) file.
+  std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+  for (const std::string& shard_file : shard_files) {
+    ASSERT_OK(writer->Open(shard_file));
+    ASSERT_OK(writer->Close());
+  }
+
+  ShardMerger<std::string> merger;
+  EXPECT_OK(merger.Delete(shard_files));
+
+  std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+  for (const std::string& shard_file : shard_files) {
+    EXPECT_FALSE(reader->Open(shard_file).ok());
+  }
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/util/status.inc b/util/status.inc
index aa84e4b..e645841 100644
--- a/util/status.inc
+++ b/util/status.inc
@@ -21,7 +21,7 @@
namespace private_join_and_compute {
// Aliases StatusCode to be compatible with our code.
using StatusCode = ::absl::StatusCode;
-// Aliases Status, StatusOr and canonical errors. This alias exists for
+// Aliases Status, StatusOr and canonical errors. This alias exists for
// historical reasons (when this library had a fork of absl::Status).
using Status = absl::Status;
template <typename T>
diff --git a/util/status_macros.h b/util/status_macros.h
index b760085..005239c 100644
--- a/util/status_macros.h
+++ b/util/status_macros.h
@@ -37,6 +37,21 @@
} \
lhs = std::move(statusor).value()
+// Helper macro that checks if the given expression evaluates to a
+// Status with Status OK. If not, returns the error status. Example:
+// RETURN_IF_ERROR(expression);
+//
+// The expression is evaluated exactly once and stored in a local variable
+// whose name embeds __LINE__. The expansion is a declaration followed by an
+// `if` statement, so it must not be used as the sole body of an unbraced
+// if/else/loop, and two uses on the same source line would declare the same
+// variable twice.
+#define RETURN_IF_ERROR(expr) \
+ PRIVACY_BLINDERS_RETURN_IF_ERROR_IMPL_( \
+ PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_(status_value, __LINE__), \
+ expr)
+
+// Internal helper.
+// `status` is the name of the local variable to declare and `expr` the
+// expression it is initialized from; `status` is returned from the enclosing
+// function when status.ok() is false.
+#define PRIVACY_BLINDERS_RETURN_IF_ERROR_IMPL_(status, expr) \
+ auto status = (expr); \
+ if (ABSL_PREDICT_FALSE(!status.ok())) { \
+ return status; \
+ }
+
// Internal helper for concatenating macro values.
#define PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
#define PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_(x, y) \
diff --git a/util/status_matchers.h b/util/status_matchers.h
new file mode 100644
index 0000000..1ec6890
--- /dev/null
+++ b/util/status_matchers.h
@@ -0,0 +1,258 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef UTIL_STATUS_MATCHERS_H_
+#define UTIL_STATUS_MATCHERS_H_
+
+#include "util/status.inc"
+#include <gmock/gmock.h>
+
+namespace private_join_and_compute {
+namespace testing {
+
+#ifdef GTEST_HAS_STATUS_MATCHERS
+
+using ::testing::status::IsOk;
+using ::testing::status::IsOkAndHolds;
+using ::testing::status::StatusIs;
+
+#else // GTEST_HAS_STATUS_MATCHERS
+
+// copybara:strip_begin(Remove sensitive comments)
+//
+// This is partially copied from the code in
+// google3/testing/base/public/gmock_utils/status-matchers.h
+// but is somewhat simplified to only include the things we need for the PJC
+// library.
+//
+// copybara:strip_end
+
+namespace internal {
+
+// Uniform accessor used by the matchers in this header: given either a Status
+// or a StatusOr<T>, returns the underlying Status by value, so that one
+// matcher implementation can serve both types.
+inline Status GetStatus(const Status& status) {
+  return status;
+}
+
+template <typename T>
+inline Status GetStatus(const StatusOr<T>& statusor) {
+  const Status& inner = statusor.status();
+  return inner;
+}
+
+// Monomorphic matcher for a value of type StatusType (a Status or a
+// StatusOr<T>, possibly a reference). It matches when the status code
+// satisfies `code` and the status message satisfies `message`.
+template <typename StatusType>
+class StatusIsImpl : public ::testing::MatcherInterface<StatusType> {
+ public:
+  StatusIsImpl(const ::testing::Matcher<StatusCode>& code,
+               const ::testing::Matcher<const std::string&>& message)
+      : code_(code), message_(message) {}
+
+  // Checks the code first and the message second. On the first mismatch the
+  // text collected from the sub-matchers is forwarded to `listener` and false
+  // is returned.
+  bool MatchAndExplain(
+      StatusType status,
+      ::testing::MatchResultListener* listener) const override {
+    ::testing::StringMatchResultListener str_listener;
+    Status real_status = GetStatus(status);
+    if (!code_.MatchAndExplain(real_status.code(), &str_listener)) {
+      *listener << str_listener.str();
+      return false;
+    }
+    // message() is converted to std::string because the message matcher is
+    // declared over const std::string&.
+    const std::string message(real_status.message());
+    if (!message_.MatchAndExplain(message, &str_listener)) {
+      *listener << str_listener.str();
+      return false;
+    }
+    return true;
+  }
+
+  void DescribeTo(std::ostream* os) const override {
+    *os << "has a status code that ";
+    code_.DescribeTo(os);
+    *os << " and a message that ";
+    message_.DescribeTo(os);
+  }
+
+  // The matcher fails when either part fails, so the negated description
+  // joins the two negated parts with "or" rather than "and".
+  void DescribeNegationTo(std::ostream* os) const override {
+    *os << "has a status code that ";
+    code_.DescribeNegationTo(os);
+    *os << " or a message that ";
+    message_.DescribeNegationTo(os);
+  }
+
+ private:
+  ::testing::Matcher<StatusCode> code_;
+  ::testing::Matcher<const std::string&> message_;
+};
+
+// Polymorphic holder for the code and message matchers of StatusIs(). It can
+// be converted to a ::testing::Matcher<StatusType> for any StatusType that
+// StatusIsImpl supports, which lets one StatusIs(...) expression be used
+// with both Status and StatusOr values.
+class StatusIsPoly {
+ public:
+  // The matchers arrive as rvalues, so they are moved into the members
+  // instead of being copied.
+  StatusIsPoly(::testing::Matcher<StatusCode>&& code,
+               ::testing::Matcher<const std::string&>&& message)
+      : code_(std::move(code)), message_(std::move(message)) {}
+
+  // Converts this polymorphic matcher to a monomorphic matcher.
+  template <typename StatusType>
+  operator ::testing::Matcher<StatusType>() const {
+    return ::testing::Matcher<StatusType>(
+        new StatusIsImpl<StatusType>(code_, message_));
+  }
+
+ private:
+  ::testing::Matcher<StatusCode> code_;
+  ::testing::Matcher<const std::string&> message_;
+};
+
+} // namespace internal
+
+// Returns a polymorphic matcher that checks the code and the message of a
+// Status, or of the Status held by a StatusOr. Because the result is
+// polymorphic, tests need no template argument to choose between the two.
+inline internal::StatusIsPoly StatusIs(
+    ::testing::Matcher<StatusCode>&& code,
+    ::testing::Matcher<const std::string&>&& message) {
+  // Both parameters are plain rvalue references (not forwarding references),
+  // so handing them on is a move.
+  return internal::StatusIsPoly(std::move(code), std::move(message));
+}
+
+// copybara:strip_begin(Remove sensitive comments)
+//
+// This is partially copied from the code in
+// third_party/absl/status/statusor_test.cc
+//
+// copybara:strip_end
+
+// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a
+// reference to StatusOr<T>. The matcher succeeds when the StatusOr is OK and
+// the value it holds satisfies the inner matcher.
+template <typename StatusOrType>
+class IsOkAndHoldsMatcherImpl
+    : public ::testing::MatcherInterface<StatusOrType> {
+ public:
+  // The T of StatusOr<T>, with any reference on StatusOrType stripped first.
+  using value_type =
+      typename std::remove_reference<StatusOrType>::type::value_type;
+
+  template <typename InnerMatcher>
+  explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
+      : inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
+            std::forward<InnerMatcher>(inner_matcher))) {}
+
+  void DescribeTo(std::ostream* os) const override {
+    *os << "is OK and has a value that ";
+    inner_matcher_.DescribeTo(os);
+  }
+
+  void DescribeNegationTo(std::ostream* os) const override {
+    *os << "isn't OK or has a value that ";
+    inner_matcher_.DescribeNegationTo(os);
+  }
+
+  bool MatchAndExplain(
+      StatusOrType actual_value,
+      ::testing::MatchResultListener* result_listener) const override {
+    if (actual_value.ok()) {
+      ::testing::StringMatchResultListener inner_listener;
+      const bool matched =
+          inner_matcher_.MatchAndExplain(*actual_value, &inner_listener);
+      // Only mention the held value when the inner matcher had something to
+      // say about it.
+      const std::string explanation = inner_listener.str();
+      if (!explanation.empty()) {
+        *result_listener << "which contains value "
+                         << ::testing::PrintToString(*actual_value) << ", "
+                         << explanation;
+      }
+      return matched;
+    }
+
+    // Not OK: report the status; the inner matcher is not consulted.
+    *result_listener << "which has status " << actual_value.status();
+    return false;
+  }
+
+ private:
+  const ::testing::Matcher<const value_type&> inner_matcher_;
+};
+
+// Polymorphic wrapper behind IsOkAndHolds(m): it stores the inner matcher and
+// produces a monomorphic matcher for whatever StatusOr type it is converted
+// to.
+template <typename InnerMatcher>
+class IsOkAndHoldsMatcher {
+ public:
+  explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
+      : inner_matcher_(std::move(inner_matcher)) {}
+
+  // Converts this polymorphic matcher to a monomorphic matcher of the
+  // given type. StatusOrType can be either StatusOr<T> or a
+  // reference to StatusOr<T>.
+  template <typename StatusOrType>
+  operator ::testing::Matcher<StatusOrType>() const {  // NOLINT
+    // The implementation is always instantiated over a const reference.
+    using Impl = IsOkAndHoldsMatcherImpl<const StatusOrType&>;
+    return ::testing::Matcher<StatusOrType>(new Impl(inner_matcher_));
+  }
+
+ private:
+  const InnerMatcher inner_matcher_;
+};
+
+// Monomorphic implementation of matcher IsOk() for a given type T.
+// T can be Status, StatusOr<>, or a reference to either of them.
+template <typename T>
+class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+  void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
+  void DescribeNegationTo(std::ostream* os) const override {
+    *os << "is not OK";
+  }
+  // Matches when the Status extracted from `actual_value` is OK. Nothing is
+  // written to the listener.
+  bool MatchAndExplain(T actual_value,
+                       ::testing::MatchResultListener*) const override {
+    // GetStatus is declared in namespace internal, which does not enclose
+    // this class, so the call must be qualified for name lookup to find it.
+    return internal::GetStatus(actual_value).ok();
+  }
+};
+
+// Implements IsOk() as a polymorphic matcher.
+// The class holds no state; converting it to ::testing::Matcher<T> creates a
+// MonoIsOkMatcherImpl<T> for the requested type T.
+class IsOkMatcher {
+ public:
+ template <typename T>
+ operator ::testing::Matcher<T>() const { // NOLINT
+ return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>());
+ }
+};
+
+// Builds the polymorphic IsOkAndHolds(m) matcher, which matches a StatusOr<>
+// that is OK and whose value satisfies `inner_matcher`. The inner matcher is
+// stored in the returned object under its decayed type.
+template <typename InnerMatcher>
+auto IsOkAndHolds(InnerMatcher&& inner_matcher)
+    -> IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> {
+  using Stored = typename std::decay<InnerMatcher>::type;
+  return IsOkAndHoldsMatcher<Stored>(
+      std::forward<InnerMatcher>(inner_matcher));
+}
+
+// Creates the polymorphic IsOk() matcher, usable with both Status and
+// StatusOr<> values; it matches when the status is OK.
+inline IsOkMatcher IsOk() {
+  return {};
+}
+
+#endif // GTEST_HAS_STATUS_MATCHERS
+
+} // namespace testing
+} // namespace private_join_and_compute
+
+#endif // UTIL_STATUS_MATCHERS_H_
diff --git a/util/status_testing.h b/util/status_testing.h
new file mode 100644
index 0000000..ab269fa
--- /dev/null
+++ b/util/status_testing.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef UTIL_STATUS_TESTING_H_
+#define UTIL_STATUS_TESTING_H_
+
+#include "util/status.inc"
+#include <gmock/gmock.h>
+
+#ifndef GTEST_HAS_STATUS_MATCHERS
+
+// Fatal assertion that `expr` yields an OK status. `expr` is evaluated once
+// and stored in a local variable named _status<line>; the check is made on
+// its ok() result, so the failure output shows true/false rather than the
+// status itself.
+// The expansion is a declaration followed by an assertion statement: do not
+// use it as the sole body of an unbraced if/else/loop, and do not use it
+// twice on the same source line.
+#define ASSERT_OK(expr) \
+ PRIVACY_BLINDERS_ASSERT_OK_IMPL_( \
+ PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), expr)
+
+// Internal helper: `status` is the name of the variable to declare.
+#define PRIVACY_BLINDERS_ASSERT_OK_IMPL_(status, expr) \
+ auto status = (expr); \
+ ASSERT_THAT(status.ok(), ::testing::Eq(true));
+
+// Non-fatal counterpart of ASSERT_OK: evaluates `expr` once into a local
+// variable named _status<line> and expects its ok() result to be true.
+// The expansion is a declaration followed by an expectation statement: do not
+// use it as the sole body of an unbraced if/else/loop, and do not use it
+// twice on the same source line.
+#define EXPECT_OK(expr) \
+ PRIVACY_BLINDERS_EXPECT_OK_IMPL_( \
+ PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), expr)
+
+// Internal helper: `status` is the name of the variable to declare.
+#define PRIVACY_BLINDERS_EXPECT_OK_IMPL_(status, expr) \
+ auto status = (expr); \
+ EXPECT_THAT(status.ok(), ::testing::Eq(true));
+
+// Evaluates `rexpr` once into a local variable named _status_or_value<line>,
+// fatally asserts that its ok() result is true, and then assigns the moved-out
+// value() to `lhs`. `lhs` may be a declaration, since it is placed on the
+// left of a plain `=`.
+// The expansion consists of several statements: do not use it as the sole
+// body of an unbraced if/else/loop, and do not use it twice on the same
+// source line.
+#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
+ PRIVACY_BLINDERS_ASSERT_OK_AND_ASSIGN_IMPL_( \
+ PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status_or_value, \
+ __LINE__), \
+ lhs, rexpr)
+
+// Internal helper: `statusor` is the name of the variable to declare.
+#define PRIVACY_BLINDERS_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ ASSERT_THAT(statusor.ok(), ::testing::Eq(true)); \
+ lhs = std::move(statusor).value()
+
+// Internal helper for concatenating macro values.
+// The extra level of indirection makes the preprocessor expand the arguments
+// (e.g. __LINE__) before they are pasted together with ##.
+#define PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) x##y
+#define PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(x, y) \
+ PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y)
+
+#endif // GTEST_HAS_STATUS_MATCHERS
+
+#endif // UTIL_STATUS_TESTING_H_
diff --git a/util/status_testing.inc b/util/status_testing.inc
new file mode 100644
index 0000000..1f07284
--- /dev/null
+++ b/util/status_testing.inc
@@ -0,0 +1,17 @@
+/*
+ * Copyright 2019 Google Inc.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/status_matchers.h"
+#include "util/status_testing.h"