diff options
author | Karn Seth <karn@google.com> | 2021-01-13 19:17:50 +0000 |
---|---|---|
committer | Karn Seth <karn@google.com> | 2021-01-13 19:17:50 +0000 |
commit | 52c605f88b976d3ec386b09af0e72dec1e40d9a4 (patch) | |
tree | a3f26085acc24e53c31fb50ff3515ebffffb0dcc | |
parent | 884e999bde8f6c48e81c239eed95b7fcbaeb70ca (diff) | |
download | private-join-and-compute-52c605f88b976d3ec386b09af0e72dec1e40d9a4.tar.gz |
adds libraries for status testing, slight modifications to bignum
-rw-r--r-- | WORKSPACE | 16 | ||||
-rw-r--r-- | crypto/big_num.cc | 21 | ||||
-rw-r--r-- | crypto/big_num.h | 8 | ||||
-rw-r--r-- | crypto/paillier.cc | 1 | ||||
-rw-r--r-- | private_join_and_compute_rpc_impl.h | 2 | ||||
-rw-r--r-- | util/BUILD | 110 | ||||
-rw-r--r-- | util/file.cc | 76 | ||||
-rw-r--r-- | util/file.h | 114 | ||||
-rw-r--r-- | util/file_posix.cc | 167 | ||||
-rw-r--r-- | util/file_test.cc | 148 | ||||
-rw-r--r-- | util/file_test.proto | 23 | ||||
-rw-r--r-- | util/proto_util.h | 52 | ||||
-rw-r--r-- | util/proto_util_test.cc | 39 | ||||
-rw-r--r-- | util/recordio.cc | 607 | ||||
-rw-r--r-- | util/recordio.h | 273 | ||||
-rw-r--r-- | util/recordio_test.cc | 508 | ||||
-rw-r--r-- | util/status.inc | 2 | ||||
-rw-r--r-- | util/status_macros.h | 15 | ||||
-rw-r--r-- | util/status_matchers.h | 258 | ||||
-rw-r--r-- | util/status_testing.h | 74 | ||||
-rw-r--r-- | util/status_testing.inc | 17 |
21 files changed, 2517 insertions, 14 deletions
@@ -68,6 +68,22 @@ http_archive( ], ) +# gtest. +git_repository( + name = "com_github_google_googletest", + commit = "703bd9caab50b139428cea1aaff9974ebee5742e", # tag = "release-1.10.0" + remote = "https://github.com/google/googletest.git", + shallow_since = "1570114335 -0400", +) + +# Protobuf +git_repository( + name = "com_google_protobuf", + remote = "https://github.com/protocolbuffers/protobuf.git", + commit = "9647a7c2356a9529754c07235a2877ee676c2fd0", + shallow_since = "1609366209 -0800", +) + load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") # Includes boringssl, and other dependencies. diff --git a/crypto/big_num.cc b/crypto/big_num.cc index f886b0f..bba3f85 100644 --- a/crypto/big_num.cc +++ b/crypto/big_num.cc @@ -27,6 +27,23 @@ namespace private_join_and_compute { +namespace { + +// Utility class for decimal string conversion. +class BnString { + public: + explicit BnString(char* bn_char) : bn_char_(bn_char) {} + + ~BnString() { OPENSSL_free(bn_char_); } + + std::string ToString() { return std::string(bn_char_); } + + private: + char* const bn_char_; +}; + +} // namespace + BigNum::BigNum(const BigNum& other) : bn_(BignumPtr(CHECK_NOTNULL(BN_dup(other.bn_.get())))), bn_ctx_(other.bn_ctx_) {} @@ -90,6 +107,10 @@ StatusOr<uint64_t> BigNum::ToIntValue() const { return val; } +std::string BigNum::ToDecimalString() const { + return BnString(BN_bn2dec(GetConstBignumPtr())).ToString(); +} + int BigNum::BitLength() const { return BN_num_bits(bn_.get()); } bool BigNum::IsPrime(double prime_error_probability) const { diff --git a/crypto/big_num.h b/crypto/big_num.h index da96885..702f0de 100644 --- a/crypto/big_num.h +++ b/crypto/big_num.h @@ -19,6 +19,7 @@ #include <stdint.h> #include <memory> +#include <ostream> #include <string> #include "crypto/openssl.inc" @@ -57,6 +58,9 @@ class ABSL_MUST_USE_RESULT BigNum { // error code if the value of *this is larger than 64 bits. StatusOr<uint64_t> ToIntValue() const; + // Returns a string representation of the BigNum as a decimal number. + std::string ToDecimalString() const; + // Returns the bit length of this BigNum. int BitLength() const; @@ -239,6 +243,10 @@ inline BigNum& operator>>=(BigNum& a, int n) { return a = a >> n; } inline BigNum& operator<<=(BigNum& a, int n) { return a = a << n; } +inline std::ostream& operator<<(std::ostream& strm, const BigNum& a) { + return strm << "BigNum(" << a.ToDecimalString() << ")"; +} + } // namespace private_join_and_compute #endif // CRYPTO_BIG_NUM_H_ diff --git a/crypto/paillier.cc b/crypto/paillier.cc index 39400b2..dcaaf5e 100644 --- a/crypto/paillier.cc +++ b/crypto/paillier.cc @@ -416,7 +416,6 @@ StatusOr<BigNum> PublicPaillier::EncryptUsingGeneratorAndRand( return c.ModMul(g_n_to_r, modulus_); } - StatusOr<BigNum> PublicPaillier::EncryptWithRand(const BigNum& m, const BigNum& r) const { if (r.Gcd(n_) != ctx_->One()) { diff --git a/private_join_and_compute_rpc_impl.h b/private_join_and_compute_rpc_impl.h index a5c68a6..5ae4bde 100644 --- a/private_join_and_compute_rpc_impl.h +++ b/private_join_and_compute_rpc_impl.h @@ -16,8 +16,6 @@ #ifndef OPEN_SOURCE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_ #define OPEN_SOURCE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_ -#define GLOG_NO_ABBREVIATED_SEVERITIES -#include "glog/logging.h" #include "include/grpcpp/grpcpp.h" #include "include/grpcpp/server_context.h" #include "include/grpcpp/support/status.h" @@ -14,6 +14,8 @@ # Build file for util folder in open-source Private Join and Compute. +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") + package( default_visibility = ["//visibility:public"], features = [ @@ -23,13 +25,12 @@ package( ) cc_library( - name = "status", - srcs = glob( - ["*.cc"], - ), - hdrs = glob(["*.h"]), + name = "status_includes", + hdrs = [ + "status.inc", + "status_macros.h", + ], deps = [ - "@com_github_glog_glog//:glog", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf_lite", @@ -37,12 +38,101 @@ cc_library( ) cc_library( - name = "status_includes", - hdrs = ["status.inc"], + name = "status_testing_includes", + hdrs = [ + "status_matchers.h", + "status_testing.h", + "status_testing.inc", + ], + deps = [ + ":status_includes", + "@com_github_google_googletest//:gtest", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "file", + srcs = [ + "file.cc", + "file_posix.cc", + ], + hdrs = [ + "file.h", + ], + deps = [ + ":status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "file_test", + size = "small", + srcs = [ + "file_test.cc", + ], + deps = [ + ":file", + "@com_github_google_googletest//:gtest_main", + ], +) + +grpc_proto_library( + name = "file_test_proto", + srcs = ["file_test.proto"], +) + +cc_library( + name = "proto_util", + hdrs = ["proto_util.h"], deps = [ - ":status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_lite", + ], +) + +cc_test( + name = "proto_util_test", + size = "medium", + srcs = ["proto_util_test.cc"], + deps = [ + ":file_test_proto", + ":proto_util", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "recordio", + srcs = [ + "recordio.cc", + ], + hdrs = ["recordio.h"], + deps = [ + ":file", + ":status_includes", + "@com_github_glog_glog//:glog", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_protobuf//:protobuf_lite", ], ) + +cc_test( + name = "recordio_test", + srcs = ["recordio_test.cc"], + deps = [ + ":file_test_proto", + ":proto_util", + ":recordio", + ":status_includes", + ":status_testing_includes", + "//crypto:bn_util", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + ], +) diff --git a/util/file.cc b/util/file.cc new file mode 100644 index 0000000..dae55be --- /dev/null +++ b/util/file.cc @@ -0,0 +1,76 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Common implementations. + +#include "util/file.h" + +#include <sstream> + +namespace private_join_and_compute { +namespace internal { +namespace { + +bool IsAbsolutePath(absl::string_view path) { + return !path.empty() && path[0] == '/'; +} + +bool EndsWithSlash(absl::string_view path) { + return !path.empty() && path[path.size() - 1] == '/'; +} + +} // namespace + +std::string JoinPathImpl(std::initializer_list<std::string> paths) { + std::string joined_path; + int size = paths.size(); + + int counter = 1; + for (auto it = paths.begin(); it != paths.end(); ++it, ++counter) { + std::string path = *it; + if (path.empty()) { + continue; + } + + if (it == paths.begin()) { + joined_path += path; + if (!EndsWithSlash(path)) { + joined_path += "/"; + } + continue; + } + + if (EndsWithSlash(path)) { + if (IsAbsolutePath(path)) { + joined_path += path.substr(1, path.size() - 2); + } else { + joined_path += path.substr(0, path.size() - 1); + } + } else { + if (IsAbsolutePath(path)) { + joined_path += path.substr(1); + } else { + joined_path += path; + } + } + if (counter != size) { + joined_path += "."; + } + } + return joined_path; +} + +} // namespace internal +} // namespace private_join_and_compute diff --git a/util/file.h b/util/file.h new file mode 100644 index 0000000..cb4911d --- /dev/null +++ b/util/file.h @@ -0,0 +1,114 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// copybara:strip_begin(internal comment) +// Abstract File class for different file systems. The implementation in +// file_posix.cc is used on NaCl, WASM and Open source, and the implementation +// in file_google3.cc is used on standard google3. +// copybara:strip_end + +#ifndef INTERNAL_UTIL_FILE_H_ +#define INTERNAL_UTIL_FILE_H_ + +#include <string> + +#include "util/status.inc" + +namespace private_join_and_compute { + +// Renames a file. Overwrites the new file if it exists. +// Returns Status::OK for success. +// Error code in case of an error depends on the underlying implementation. +Status RenameFile(absl::string_view from, absl::string_view to); + +// Deletes a file. +// Returns Status::OK for success. +// Error code in case of an error depends on the underlying implementation. +Status DeleteFile(absl::string_view file_name); + +class File { + public: + virtual ~File() = default; + + // Opens the file_name for file operations applicable based on mode. + // Returns Status::OK for success. + // Error code in case of an error depends on the underlying implementation. + virtual Status Open(absl::string_view file_name, absl::string_view mode) = 0; + + // Closes the opened file. Must be called after opening a file. + // Returns Status::OK for success. + // Error code in case of an error depends on the underlying implementation. + virtual Status Close() = 0; + + // Returns true if there are more data in the file to be read. + // Returns a status instead in case of an io error in determining if there is + // more data. + virtual StatusOr<bool> HasMore() = 0; + + // Returns a data string of size length from reading file if successful. + // Returns a status in case of an error. + // This would also return an error status if the read data size is less than + // the length since it indicates file corruption. + virtual StatusOr<std::string> Read(size_t length) = 0; + + // Returns a line as string from the file without the trailing '\n' (or "\r\n" + // in the case of Windows). + // + // Returns a status in case of an error. + virtual StatusOr<std::string> ReadLine() = 0; + + // Writes the given content of size length into the file. + // Error code in case of an error depends on the underlying implementation. + virtual Status Write(absl::string_view content, size_t length) = 0; + + // Returns a File object depending on the linked implementation. + // Caller takes the ownership. + static File* GetFile(); + + protected: + File() = default; +}; + +namespace internal { +std::string JoinPathImpl(std::initializer_list<std::string> paths); +} // namespace internal + +// Joins multiple paths together such that only the first argument directory +// structure is represented. A dot as a separator is added for other arguments. +// +// Arguments | JoinPath | +// ---------------------------+---------------------+ +// '/foo', 'bar' | /foo/bar | +// '/foo/', 'bar' | /foo/bar | +// '/foo', '/bar' | /foo/bar | +// '/foo', '/bar', '/baz' | /foo/bar.baz | +// +// All paths will be treated as relative paths, regardless of whether or not +// they start with a leading '/'. That is, all paths will be concatenated +// together, with the appropriate path separator inserted in between. +// After the first path, all paths will be joined with a dot instead of the path +// separator so that there is no level of directory after the first argument. +// Arguments must be convertible to string. +// +// Usage: +// string path = file::JoinPath("/tmp", dirname, filename); +template <typename... T> +std::string JoinPath(const T&... args) { + return internal::JoinPathImpl({args...}); +} + +} // namespace private_join_and_compute + +#endif // INTERNAL_UTIL_FILE_H_ diff --git a/util/file_posix.cc b/util/file_posix.cc new file mode 100644 index 0000000..70806ff --- /dev/null +++ b/util/file_posix.cc @@ -0,0 +1,167 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <limits.h> +#include <stdio.h> +#include <stdlib.h> + +#include "util/file.h" +#include "util/status.inc" +#include "absl/strings/str_cat.h" + +namespace private_join_and_compute { +namespace { + +class PosixFile : public File { + public: + PosixFile() : File(), f_(nullptr), current_fname_() {} + + ~PosixFile() override { + if (f_) Close().IgnoreError(); + } + + Status Open(absl::string_view file_name, absl::string_view mode) final { + if (nullptr != f_) { + return InternalError( + absl::StrCat("Open failed:", "File with name ", current_fname_, + " has already been opened, close it first.")); + } + f_ = fopen(file_name.data(), mode.data()); + if (nullptr == f_) { + return absl::NotFoundError( + absl::StrCat("Open failed:", "Error opening file ", file_name)); + } + current_fname_ = std::string(file_name); + return OkStatus(); + } + + Status Close() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("Close failed:", "There is no opened file.")); + } + if (fclose(f_)) { + return InternalError( + absl::StrCat("Close failed:", "Error closing file ", current_fname_)); + } + f_ = nullptr; + return OkStatus(); + } + + StatusOr<bool> HasMore() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("HasMore failed:", "There is no opened file.")); + } + if (feof(f_)) return false; + if (ferror(f_)) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error indicator has been set for file ", + current_fname_)); + } + int c = getc(f_); + if (ferror(f_)) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error reading a single character from the file ", + current_fname_)); + } + if (ungetc(c, f_) != c) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error putting back the peeked character ", + "into the file ", current_fname_)); + } + return c != EOF; + } + + StatusOr<std::string> Read(size_t length) final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("Read failed:", "There is no opened file.")); + } + std::vector<char> data(length); + if (fread(data.data(), 1, length, f_) != length) { + return InternalError(absl::StrCat( + "condition failed:", "Error reading the file ", current_fname_)); + } + return std::string(data.begin(), data.end()); + } + + StatusOr<std::string> ReadLine() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("ReadLine failed:", "There is no opened file.")); + } + if (fgets(buffer_, LINE_MAX, f_) == nullptr || ferror(f_)) { + return InternalError( + absl::StrCat("ReadLine failed:", "Error reading line from the file ", + current_fname_)); + } + std::string content; + int len = strlen(buffer_); + // Remove trailing '\n' if present. + if (len > 0 && buffer_[len - 1] == '\n') { + // Remove trailing '\r' if present (e.g. on Windows) + if (len > 1 && buffer_[len - 2] == '\r') { + content.append(buffer_, len - 2); + } else { + content.append(buffer_, len - 1); + } + } else { + // No trailing newline characters + content.append(buffer_, len); + } + return content; + } + + Status Write(absl::string_view content, size_t length) final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("ReadLine failed:", "There is no opened file.")); + } + if (fwrite(content.data(), 1, length, f_) != length) { + return InternalError(absl::StrCat( + "ReadLine failed:", "Error writing the given data into the file ", + current_fname_)); + } + return OkStatus(); + } + + private: + FILE* f_; + std::string current_fname_; + char buffer_[LINE_MAX]; +}; + +} // namespace + +File* File::GetFile() { return new PosixFile(); } + +Status RenameFile(absl::string_view from, absl::string_view to) { + if (0 != rename(from.data(), to.data())) { + return InternalError(absl::StrCat( + "RenameFile failed:", "Cannot rename file, ", from, " to file, ", to)); + } + return OkStatus(); +} + +Status DeleteFile(absl::string_view file_name) { + if (0 != remove(file_name.data())) { + return InternalError( + absl::StrCat("DeleteFile failed:", "Cannot delete file, ", file_name)); + } + return OkStatus(); +} + +} // namespace private_join_and_compute diff --git a/util/file_test.cc b/util/file_test.cc new file mode 100644 index 0000000..291f3c9 --- /dev/null +++ b/util/file_test.cc @@ -0,0 +1,148 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/file.h" + +#include "util/status.inc" +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace private_join_and_compute { +namespace { + +template <typename T1, typename T2> +void AssertOkAndHolds(const T1& expected_value, const StatusOr<T2>& status_or) { + EXPECT_TRUE(status_or.ok()) << status_or.status(); + EXPECT_EQ(expected_value, status_or.value()); +} + +class FileTest : public testing::Test { + public: + FileTest() : testing::Test(), f_(File::GetFile()) {} + + std::unique_ptr<File> f_; +}; + +TEST_F(FileTest, WriteDataThenReadTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + EXPECT_TRUE(f_->Write("water", 4).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "rb").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("wat", f_->Read(3)); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("e", f_->Read(1)); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, ReadLineTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + EXPECT_TRUE(f_->Write("Line1\nLine2\n\n", 13).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line1", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line2", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("", f_->ReadLine()); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, CannotOpenFileIfAnotherFileIsAlreadyOpened) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp1.txt", "w").ok()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFile) { + EXPECT_FALSE(f_->Close().ok()); + EXPECT_FALSE(f_->HasMore().ok()); + EXPECT_FALSE(f_->Read(1).ok()); + EXPECT_FALSE(f_->ReadLine().ok()); + EXPECT_FALSE(f_->Write("w", 1).ok()); +} + +TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFileAfterClosing) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_FALSE(f_->Close().ok()); + EXPECT_FALSE(f_->HasMore().ok()); + EXPECT_FALSE(f_->Read(1).ok()); + EXPECT_FALSE(f_->ReadLine().ok()); + EXPECT_FALSE(f_->Write("w", 1).ok()); +} + +TEST_F(FileTest, TestRename) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Write("water", 5).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(RenameFile(testing::TempDir() + "/tmp.txt", + testing::TempDir() + "/tmp1.txt") + .ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp1.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("water", f_->Read(5)); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, TestDelete) { + // Create file and delete it. + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Write("water", 5).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(DeleteFile(testing::TempDir() + "/tmp.txt").ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + + // Try to delete nonexistent file. + EXPECT_FALSE(DeleteFile(testing::TempDir() + "/tmp2.txt").ok()); +} + +TEST_F(FileTest, JoinPathWithMultipleArgs) { + std::string ret = JoinPath("/tmp", "foo", "bar/", "/baz/"); + EXPECT_EQ("/tmp/foo.bar.baz", ret); +} + +TEST_F(FileTest, JoinPathWithMultipleArgsStartingWithEndSlashDir) { + std::string ret = JoinPath("/tmp/", "foo", "bar/", "/baz/"); + EXPECT_EQ("/tmp/foo.bar.baz", ret); +} + +TEST_F(FileTest, ReadLineWithCarriageReturnsTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + std::string file_string = "Line1\nLine2\r\nLine3\r\nLine4\n\n"; + EXPECT_TRUE(f_->Write(file_string, file_string.size()).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line1", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line2", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line3", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line4", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("", f_->ReadLine()); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/util/file_test.proto b/util/file_test.proto new file mode 100644 index 0000000..3bf8e96 --- /dev/null +++ b/util/file_test.proto @@ -0,0 +1,23 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute.testing; + +message TestProto { + optional bytes record = 1; + optional bytes dummy = 2; +} diff --git a/util/proto_util.h b/util/proto_util.h new file mode 100644 index 0000000..8239751 --- /dev/null +++ b/util/proto_util.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Protocol buffer related static utility functions. + +#ifndef INTERNAL_UTIL_PROTO_UTIL_H_ +#define INTERNAL_UTIL_PROTO_UTIL_H_ + +#include <sstream> +#include <string> + +#include "src/google/protobuf/message_lite.h" +#include "absl/strings/string_view.h" + +namespace private_join_and_compute { + +class ProtoUtils { + public: + template <typename ProtoType> + static ProtoType FromString(absl::string_view raw_data); + + static std::string ToString(const google::protobuf::MessageLite& record); +}; + +template <typename ProtoType> +inline ProtoType ProtoUtils::FromString(absl::string_view raw_data) { + ProtoType record; + record.ParseFromArray(raw_data.data(), raw_data.size()); + return record; +} + +inline std::string ProtoUtils::ToString(const google::protobuf::MessageLite& record) { + std::ostringstream record_str_stream; + record.SerializeToOstream(&record_str_stream); + return record_str_stream.str(); +} + +} // namespace private_join_and_compute + +#endif // INTERNAL_UTIL_PROTO_UTIL_H_ diff --git a/util/proto_util_test.cc b/util/proto_util_test.cc new file mode 100644 index 0000000..ecb486d --- /dev/null +++ b/util/proto_util_test.cc @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/proto_util.h" + +#include "util/file_test.pb.h" +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace private_join_and_compute { + +namespace { +using testing::TestProto; + +TEST(ProtoUtilsTest, ConvertsToAndFrom) { + TestProto expected_test_proto; + expected_test_proto.set_record("data"); + expected_test_proto.set_dummy("dummy"); + std::string serialized = ProtoUtils::ToString(expected_test_proto); + TestProto actual_test_proto = ProtoUtils::FromString<TestProto>(serialized); + EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record()); + EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy()); +} + +} // namespace + +} // namespace private_join_and_compute diff --git a/util/recordio.cc b/util/recordio.cc new file mode 100644 index 0000000..d509632 --- /dev/null +++ b/util/recordio.cc @@ -0,0 +1,607 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/recordio.h" + +#include <algorithm> +#include <functional> +#include <list> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#define GLOG_NO_ABBREVIATED_SEVERITIES +#include "glog/logging.h" +#include "src/google/protobuf/io/coded_stream.h" +#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "util/status.inc" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" + +namespace private_join_and_compute { + +namespace { + +// Max. size of a Varint32 (from proto references). +const uint32_t kMaxVarint32Size = 5; + +// Tries to read a Varint32 from the front of a given file. Returns false if the +// reading fails. +StatusOr<uint32_t> ExtractVarint32(File* file) { + // Keep reading a single character until one is found such that the top bit is + // 0; + std::string bytes_read = ""; + + size_t current_byte = 0; + ASSIGN_OR_RETURN(auto has_more, file->HasMore()); + while (current_byte < kMaxVarint32Size && has_more) { + auto maybe_last_byte = file->Read(1); + if (!maybe_last_byte.ok()) { + return maybe_last_byte.status(); + } + + bytes_read += maybe_last_byte.value(); + if (!(bytes_read.data()[current_byte] & 0x80)) { + break; + } + current_byte++; + // If we read the max number of bits and never found a "terminating" byte, + // return false. + if (current_byte >= kMaxVarint32Size) { + return InvalidArgumentError( + "ExtractVarint32: Failed to extract a Varint after reading max " + "number " + "of bytes."); + } + ASSIGN_OR_RETURN(has_more, file->HasMore()); + } + + google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(), + bytes_read.size()); + google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream); + uint32_t result; + codedInputStream.ReadVarint32(&result); + + return result; +} + +// Reads records from a file one at a time. +class RecordReaderImpl : public RecordReader { + public: + explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {} + + Status Open(absl::string_view filename) final { + return in_->Open(filename, "r"); + } + + Status Close() final { return in_->Close(); } + + StatusOr<bool> HasMore() final { + auto status_or_has_more = in_->HasMore(); + if (!status_or_has_more.ok()) { + LOG(ERROR) << status_or_has_more.status(); + } + return status_or_has_more; + } + + Status Read(std::string* raw_data) final { + raw_data->erase(); + auto maybe_record_size = ExtractVarint32(in_.get()); + if (!maybe_record_size.ok()) { + LOG(ERROR) << "RecordReader::Read: Couldn't read record size: " + << maybe_record_size.status(); + return maybe_record_size.status(); + } + uint32_t record_size = maybe_record_size.value(); + + auto status_or_data = in_->Read(record_size); + if (!status_or_data.ok()) { + LOG(ERROR) << status_or_data.status(); + return status_or_data.status(); + } + + raw_data->append(status_or_data.value()); + return OkStatus(); + } + + private: + std::unique_ptr<File> in_; +}; + +// Reads lines from a file one at a time. +class LineReader : public RecordReader { + public: + explicit LineReader(File* file) : RecordReader(), in_(file) {} + + Status Open(absl::string_view filename) final { + return in_->Open(filename, "r"); + } + + Status Close() final { return in_->Close(); } + + StatusOr<bool> HasMore() final { return in_->HasMore(); } + + Status Read(std::string* line) final { + line->erase(); + auto status_or_line = in_->ReadLine(); + if (!status_or_line.ok()) { + LOG(ERROR) << status_or_line.status(); + return status_or_line.status(); + } + line->append(status_or_line.value()); + return OkStatus(); + } + + private: + std::unique_ptr<File> in_; +}; + +template <typename T> +class MultiSortedReaderImpl : public MultiSortedReader<T> { + public: + explicit MultiSortedReaderImpl( + const std::function<RecordReader*()>& get_reader, + std::unique_ptr<std::function<T(absl::string_view)>> default_key = + nullptr) + : MultiSortedReader<T>(), + get_reader_(get_reader), + default_key_(std::move(default_key)), + key_(nullptr) {} + + Status Open(const std::vector<std::string>& filenames) override { + if (default_key_ == nullptr) { + return InvalidArgumentError("The sorting key is null."); + } + return Open(filenames, *default_key_); + } + + Status Open(const std::vector<std::string>& filenames, + const std::function<T(absl::string_view)>& key) override { + if (!readers_.empty()) { + return InternalError("There are files not closed, call Close() first."); + } + key_ = absl::make_unique<std::function<T(absl::string_view)>>(key); + for (size_t i = 0; i < filenames.size(); ++i) { + this->readers_.push_back(std::unique_ptr<RecordReader>(get_reader_())); + auto open_status = this->readers_.back()->Open(filenames[i]); + if (!open_status.ok()) { + // Try to close the opened ones. + for (int j = i - 1; j >= 0; --j) { + // If closing fails as well, then any call to Open will fail as well + // since some of the files will remain opened. + auto status = this->readers_[j]->Close(); + if (!status.ok()) { + LOG(ERROR) << "Error closing file " << status; + } + this->readers_.pop_back(); + } + return open_status; + } + } + return OkStatus(); + } + + Status Close() override { + Status status = OkStatus(); + bool ret_val = + std::all_of(readers_.begin(), readers_.end(), + [&status](std::unique_ptr<RecordReader>& reader) { + Status close_status = reader->Close(); + if (!close_status.ok()) { + status = close_status; + return false; + } else { + return true; + } + }); + if (ret_val) { + readers_ = std::vector<std::unique_ptr<RecordReader>>(); + min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>, + HeapDataGreater>(); + } + return status; + } + + StatusOr<bool> HasMore() override { + if (!min_heap_.empty()) { + return true; + } + Status status = OkStatus(); + for (const auto& reader : readers_) { + auto status_or_has_more = reader->HasMore(); + if (status_or_has_more.ok()) { + if (status_or_has_more.value()) { + return true; + } + } else { + status = status_or_has_more.status(); + } + } + if (status.ok()) { + // None of the readers has more. + return false; + } + return status; + } + + Status Read(std::string* data) override { return Read(data, nullptr); } + + Status Read(std::string* data, int* index) override { + if (min_heap_.empty()) { + for (size_t i = 0; i < readers_.size(); ++i) { + RETURN_IF_ERROR(this->ReadHeapDataFromReader(i)); + } + } + HeapData ret_data = min_heap_.top(); + data->assign(ret_data.data); + if (index != nullptr) *index = ret_data.index; + min_heap_.pop(); + return this->ReadHeapDataFromReader(ret_data.index); + } + + private: + Status ReadHeapDataFromReader(int index) { + std::string data; + auto status_or_has_more = readers_[index]->HasMore(); + if (!status_or_has_more.ok()) { + return status_or_has_more.status(); + } + if (status_or_has_more.value()) { + RETURN_IF_ERROR(readers_[index]->Read(&data)); + HeapData heap_data; + heap_data.key = (*key_)(data); + heap_data.data = data; + heap_data.index = index; + min_heap_.push(heap_data); + } + return OkStatus(); + } + + struct HeapData { + T key; + std::string data; + int index; + }; + + struct HeapDataGreater { + bool operator()(const HeapData& lhs, const HeapData& rhs) const { + return lhs.key > rhs.key; + } + }; + + const std::function<RecordReader*()> get_reader_; + std::unique_ptr<std::function<T(absl::string_view)>> default_key_; + std::unique_ptr<std::function<T(absl::string_view)>> key_; + std::vector<std::unique_ptr<RecordReader>> readers_; + std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater> + min_heap_; +}; + +// Writes records to a file one at a time. +class RecordWriterImpl : public RecordWriter { + public: + explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {} + + Status Open(absl::string_view filename) final { + return out_->Open(filename, "w"); + } + + Status Close() final { return out_->Close(); } + + Status Write(absl::string_view raw_data) final { + std::string delimited_output; + auto string_output = + absl::make_unique<google::protobuf::io::StringOutputStream>(&delimited_output); + auto coded_output = + absl::make_unique<google::protobuf::io::CodedOutputStream>(string_output.get()); + + // Write the delimited output. + coded_output->WriteVarint32(raw_data.size()); + coded_output->WriteString(std::string(raw_data)); + + // Force the serialization, which makes delimited_output safe to read. + coded_output = nullptr; + string_output = nullptr; + + return out_->Write(delimited_output, delimited_output.size()); + } + + private: + std::unique_ptr<File> out_; +}; + +// Writes lines to a file one at a time. +class LineWriterImpl : public LineWriter { + public: + explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {} + + Status Open(absl::string_view filename) final { + return out_->Open(filename, "w"); + } + + Status Close() final { return out_->Close(); } + + Status Write(absl::string_view line) final { + RETURN_IF_ERROR(out_->Write(line.data(), line.size())); + return out_->Write("\n", 1); + } + + private: + std::unique_ptr<File> out_; +}; + +} // namespace + +RecordReader* RecordReader::GetLineReader() { + return RecordReader::GetLineReader(File::GetFile()); +} + +RecordReader* RecordReader::GetLineReader(File* file) { + return new LineReader(file); +} + +RecordReader* RecordReader::GetRecordReader() { + return RecordReader::GetRecordReader(File::GetFile()); +} + +RecordReader* RecordReader::GetRecordReader(File* file) { + return new RecordReaderImpl(file); +} + +RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); } + +RecordWriter* RecordWriter::Get(File* file) { + return new RecordWriterImpl(file); +} + +LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); } + +LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); } + +template <typename T> +MultiSortedReader<T>* MultiSortedReader<T>::Get() { + return MultiSortedReader<T>::Get( + []() { return RecordReader::GetRecordReader(); }); +} + +template <> +MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get( + const std::function<RecordReader*()>& get_reader) { + return new MultiSortedReaderImpl<std::string>( + get_reader, + absl::make_unique<std::function<std::string(absl::string_view)>>( + [](absl::string_view s) { return std::string(s); })); +} + +template <> +MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get( + const std::function<RecordReader*()>& get_reader) { + return new MultiSortedReaderImpl<int64_t>( + get_reader, absl::make_unique<std::function<int64_t(absl::string_view)>>( + [](absl::string_view s) { return 0; })); +} + +template class MultiSortedReader<int64_t>; +template class MultiSortedReader<std::string>; + +namespace { + +std::string GetFilename(absl::string_view prefix, int32_t idx) { + return absl::StrCat(prefix, idx); +} + +template <typename T> +class ShardingWriterImpl : public ShardingWriter<T> { + public: + static Status AlreadyUnhealthyError() { + return InternalError("ShardingWriter: Already unhealthy."); + } + + explicit ShardingWriterImpl( + const std::function<T(absl::string_view)>& get_key, + int32_t max_bytes = 209715200, /* 200MB */ + std::unique_ptr<RecordWriter> record_writer = + absl::WrapUnique(RecordWriter::Get())) + : get_key_(get_key), + record_writer_(std::move(record_writer)), + max_bytes_(max_bytes), + cache_(), + bytes_written_(0), + current_file_idx_(0), + shard_files_(), + healthy_(true), + open_(false) {} + + void SetShardPrefix(absl::string_view shard_prefix) override { + absl::MutexLock lock(&mutex_); + open_ = true; + fnames_prefix_ = std::string(shard_prefix); + current_fname_ = GetFilename(fnames_prefix_, current_file_idx_); + } + + StatusOr<std::vector<std::string>> Close() override { + absl::MutexLock lock(&mutex_); + + auto retval = TryClose(); + + // Guarantee that the state is reset, even if TryClose fails. + fnames_prefix_ = ""; + current_fname_ = ""; + healthy_ = true; + cache_.clear(); + bytes_written_ = 0; + shard_files_.clear(); + current_file_idx_ = 0; + open_ = false; + + return retval; + } + + // Writes the supplied Record into the file. + // Returns true if the write operation was successful. + Status Write(absl::string_view raw_record) override { + absl::MutexLock lock(&mutex_); + if (!open_) { + return InternalError("Must call SetShardPrefix before calling Write."); + } + if (!healthy_) { + return AlreadyUnhealthyError(); + } + if (bytes_written_ > max_bytes_) { + RETURN_IF_ERROR(WriteCacheToFile()); + } + bytes_written_ += raw_record.size(); + cache_.push_back(std::string(raw_record)); + return OkStatus(); + } + + private: + Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (!healthy_) return AlreadyUnhealthyError(); + if (cache_.empty()) return OkStatus(); + cache_.sort([this](absl::string_view r1, absl::string_view r2) { + return get_key_(r1) < get_key_(r2); + }); + if (!record_writer_->Open(current_fname_).ok()) { + healthy_ = false; + return InternalError( + absl::StrCat("Cannot open ", current_fname_, " for writing.")); + } + Status status = absl::OkStatus(); + for (absl::string_view r : cache_) { + if (!record_writer_->Write(r).ok()) { + healthy_ = false; + status = InternalError( + absl::StrCat("Cannot write record ", r, " to ", current_fname_)); + + break; + } + } + if (!record_writer_->Close().ok()) { + if (status.ok()) { + status = + InternalError(absl::StrCat("Cannot close ", current_fname_, ".")); + } else { + // Preserve the old status message. + LOG(WARNING) << "Cannot close " << current_fname_; + } + } + + shard_files_.push_back(current_fname_); + cache_.clear(); + bytes_written_ = 0; + ++current_file_idx_; + current_fname_ = GetFilename(fnames_prefix_, current_file_idx_); + return status; + } + + StatusOr<std::vector<std::string>> TryClose() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (!open_) { + return InternalError("Must call SetShardPrefix before calling Close."); + } + RETURN_IF_ERROR(WriteCacheToFile()); + + return {shard_files_}; + } + + absl::Mutex mutex_; + std::function<T(absl::string_view)> get_key_; + std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_); + std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_); + const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_); + std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_); + int32_t bytes_written_ ABSL_GUARDED_BY(mutex_); + int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_); + std::string current_fname_ ABSL_GUARDED_BY(mutex_); + std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_); + bool healthy_ ABSL_GUARDED_BY(mutex_); + bool open_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace + +template <typename T> +std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) { + return absl::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes); +} + +// Test only. +template <typename T> +std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes, + std::unique_ptr<RecordWriter> record_writer) { + return absl::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes, + std::move(record_writer)); +} + +template class ShardingWriter<int64_t>; +template class ShardingWriter<std::string>; + +template <typename T> +ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader, + std::unique_ptr<RecordWriter> writer) + : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {} + +template <typename T> +Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key, + const std::vector<std::string>& shard_files, + absl::string_view output_file) { + if (shard_files.empty()) { + // Create an empty output file. + RETURN_IF_ERROR(writer_->Open(output_file)); + RETURN_IF_ERROR(writer_->Close()); + } + + // Multi-sorted-read all shards, and write the results to the supplied file. + std::vector<std::string> converted_shard_files; + converted_shard_files.reserve(shard_files.size()); + for (const auto& filename : shard_files) { + converted_shard_files.push_back(filename); + } + + RETURN_IF_ERROR(multi_reader_->Open(converted_shard_files, get_key)); + + RETURN_IF_ERROR(writer_->Open(output_file)); + + for (std::string record; multi_reader_->HasMore().value();) { + RETURN_IF_ERROR(multi_reader_->Read(&record)); + RETURN_IF_ERROR(writer_->Write(record)); + } + RETURN_IF_ERROR(writer_->Close()); + + RETURN_IF_ERROR(multi_reader_->Close()); + + return OkStatus(); +} + +template <typename T> +Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) { + for (const auto& filename : shard_files) { + RETURN_IF_ERROR(DeleteFile(filename)); + } + + return OkStatus(); +} + +template class ShardMerger<int64_t>; +template class ShardMerger<std::string>; + +} // namespace private_join_and_compute diff --git a/util/recordio.h b/util/recordio.h new file mode 100644 index 0000000..32d11b4 --- /dev/null +++ b/util/recordio.h @@ -0,0 +1,273 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Defines file operations. +// RecordWriter generates output records that are binary data preceded with a +// Varint that explains the size of the records. The records provided to +// RecordWriter can be arbitrary binary data, but usually they will be +// serialized protobufs. +// +// copybara:strip_begin(internal comment) +// Note that this library is not in any way compatible with the Google3 RecordIo +// class, but rather uses the suggested proto-serialization described in +// https://developers.google.com/protocol-buffers/docs/techniques#streaming +// This encoding is also compatible with Java parseDelimitedFrom and +// writeDelimitedTo. +// copybara:strip_end +// +// RecordReader reads files written in the above format, and is also compatible +// with files written using the Java version of parseDelimitedFrom and +// writeDelimitedTo. +// +// LineWriter writes single lines to the output file. LineReader reads single +// lines from the input file. +// +// Note that all classes except ShardingWriter are not thread-safe: concurrent +// accesses must be protected by mutexes. + +#ifndef INTERNAL_UTIL_RECORDIO_H_ +#define INTERNAL_UTIL_RECORDIO_H_ + +#include <functional> +#include <memory> +#include <string> +#include <vector> + +#include "util/file.h" +#include "util/status.inc" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" + +namespace private_join_and_compute { + +// Interface for reading a single file. +class RecordReader { + public: + virtual ~RecordReader() = default; + + // RecordReader is neither copyable nor movable. + RecordReader(const RecordReader&) = delete; + RecordReader& operator=(const RecordReader&) = delete; + + // Opens the given file for reading. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes any file object created via calling SingleFileReader::Open + virtual Status Close() = 0; + + // Returns true if there are more records in the file to be read. + virtual StatusOr<bool> HasMore() = 0; + + // Reads a record from the file (line or binary record). + virtual Status Read(std::string* record) = 0; + + // Returns a RecordReader for reading files line by line. + // Caller takes the ownership. + static RecordReader* GetLineReader(); + + // Returns a RecordReader for reading files in a record format compatible with + // RecordWriter below. + // Caller takes the ownership. + static RecordReader* GetRecordReader(); + + // Test only. + static RecordReader* GetLineReader(File* file); + static RecordReader* GetRecordReader(File* file); + + protected: + RecordReader() = default; +}; + +// Reads records one at a time in ascending order from multiple files, assuming +// each file stores records in ascending order. This class does the merge step +// for the external sorting. Templates T supported are string and int64. +template <typename T> +class MultiSortedReader { + public: + virtual ~MultiSortedReader() = default; + + // MultiSortedReader is neither copyable nor movable. + MultiSortedReader(const MultiSortedReader&) = delete; + MultiSortedReader& operator=(const MultiSortedReader&) = delete; + + // Opens the files generated with RecordWriterInterface. Records in each file + // are assumed to be sorted beforehand. + virtual Status Open(const std::vector<std::string>& filenames) = 0; + + // Same as Open above but also accepts a key function that is used to convert + // a string record into a value of type T, used when comparing the records. + // Records will be read from the file heads in ascending order of "key". + virtual Status Open(const std::vector<std::string>& filenames, + const std::function<T(absl::string_view)>& key) = 0; + + // Closes the file streams. + virtual Status Close() = 0; + + // Returns true if there are more records in the file to be read. + virtual StatusOr<bool> HasMore() = 0; + + // Reads a record data into <code>data</code> in ascending order. + // Erases the <code>data</code> before writing to it. + virtual Status Read(std::string* data) = 0; + + // Same as Read(string* data) but this also puts the index of the file + // where the data has been read from if index is not nullptr. + // Erases the <code>data</code> before writing to it. + virtual Status Read(std::string* data, int* index) = 0; + + // Returns a MultiSortedReader. + // Caller takes the ownership. + static MultiSortedReader<T>* Get(); + + // Test only. + static MultiSortedReader* Get( + const std::function<RecordReader*()>& get_reader); + + protected: + MultiSortedReader() = default; +}; + +class RecordWriter { + public: + virtual ~RecordWriter() = default; + + // RecordWriter is neither copyable nor movable. + RecordWriter(const RecordWriter&) = delete; + RecordWriter& operator=(const RecordWriter&) = delete; + + // Opens the given file for writing records. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes the file stream and returns true if successful. + virtual Status Close() = 0; + + // Writes <code>raw_data</code> into the file as-is, with a delimiter + // specifying the data size. + virtual Status Write(absl::string_view raw_data) = 0; + + // Returns a RecordWriter. + // Caller takes the ownership. + static RecordWriter* Get(); + + // Test only. + static RecordWriter* Get(File* file); + + protected: + RecordWriter() = default; +}; + +class LineWriter { + public: + virtual ~LineWriter() = default; + + // LineWriter is neither copyable nor movable. + LineWriter(const LineWriter&) = delete; + LineWriter& operator=(const LineWriter&) = delete; + + // Opens the given file for writing lines. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes the file stream and returns OkStatus if successful. + virtual Status Close() = 0; + + // Writes <code>line</code> into the file, with a trailing newline. + // Returns OkStatus if the write operation was successful. + virtual Status Write(absl::string_view line) = 0; + + // Returns a RecordWriter. + // Caller takes the ownership. + static LineWriter* Get(); + + // Test only. + static LineWriter* Get(File* file); + + protected: + LineWriter() = default; +}; + +// Writes Records to shard files, with each shard file internally sorted based +// on the supplied get_key method. +// +// This class is thread-safe. +template <typename T> +class ShardingWriter { + public: + virtual ~ShardingWriter() = default; + + // ShardingWriter is neither copyable nor copy-assignable. + ShardingWriter(const ShardingWriter&) = delete; + ShardingWriter& operator=(const ShardingWriter&) = delete; + + // Shards will be created with the supplied prefix. Must be called before + // Write. + virtual void SetShardPrefix(absl::string_view shard_prefix) = 0; + + // Clears the remaining cache, and returns the list of all shard files that + // were written since the last call to SetShardPrefix. Caller is responsible + // for merging and deleting shards. + // + // Returns InternalError if clearing the remaining cache fails. + virtual StatusOr<std::vector<std::string>> Close() = 0; + + // Writes the supplied str into the file. + // Implementations need not actually write the record on each call. Rather, + // they may cache records until max_bytes records have been cached, at which + // point they may sort the cache and write it to a shard file. + // + // Implementations must return InternalError if writing the cache fails, or + // if the shard prefix has not been set. + virtual Status Write(absl::string_view raw_data) = 0; + + // Returns a ShardingWriter that uses the supplied key to compare records. + // @param max_bytes: denotes the maximum size of each shard to write. + static std::unique_ptr<ShardingWriter> Get( + const std::function<T(absl::string_view)>& get_key, + int32_t max_bytes = 209715200 /* 200MB */); + + // Test only. + static std::unique_ptr<ShardingWriter> Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes, + std::unique_ptr<RecordWriter> record_writer); + + protected: + ShardingWriter() = default; +}; + +// Utility class to allow merging of sorted shards, and deleting of shards. +template <typename T> +class ShardMerger { + public: + explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader = + absl::WrapUnique(MultiSortedReader<T>::Get()), + std::unique_ptr<RecordWriter> writer = + absl::WrapUnique(RecordWriter::Get())); + + // Merges the supplied shards into a single output file, using the supplied + // key. + Status Merge(const std::function<T(absl::string_view)>& get_key, + const std::vector<std::string>& shard_files, + absl::string_view output_file); + + // Deletes the supplied shard files. + Status Delete(std::vector<std::string> shard_files); + + private: + std::unique_ptr<MultiSortedReader<T>> multi_reader_; + std::unique_ptr<RecordWriter> writer_; +}; + +} // namespace private_join_and_compute + +#endif // INTERNAL_UTIL_RECORDIO_H_ diff --git a/util/recordio_test.cc b/util/recordio_test.cc new file mode 100644 index 0000000..5ba4fdf --- /dev/null +++ b/util/recordio_test.cc @@ -0,0 +1,508 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/recordio.h" + +#include <fstream> + +#include "crypto/context.h" +#include "util/file_test.pb.h" +#include "util/proto_util.h" +#include "util/status.inc" +#include "util/status_testing.inc" +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace private_join_and_compute { +namespace { + +using ::private_join_and_compute::testing::TestProto; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using testing::IsOkAndHolds; +using testing::StatusIs; +using ::testing::TempDir; + +std::string GetTestPBWithDummyAsStr(const std::string& data, + const std::string& dummy) { + TestProto test_proto; + test_proto.set_record(data); + test_proto.set_dummy(dummy); + return ProtoUtils::ToString(test_proto); +} + +void ExpectFileContainsRecords(absl::string_view filename, + const std::vector<std::string>& expected_ids) { + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + std::vector<std::string> ids_read; + EXPECT_OK(reader->Open(filename)); + while (reader->HasMore().value()) { + std::string raw_record; + EXPECT_OK(reader->Read(&raw_record)); + ids_read.push_back(ProtoUtils::FromString<TestProto>(raw_record).record()); + } + EXPECT_THAT(ids_read, ElementsAreArray(expected_ids)); +} + +TestProto GetRecord(const std::string& id) { + TestProto record; + record.set_record(id); + return record; +} + +void ExpectInternalErrorWithSubstring(const Status& status, + const std::string& substring) { + EXPECT_THAT(status, + StatusIs(private_join_and_compute::StatusCode::kInternal, HasSubstr(substring))); +} + +TEST(FileTest, WriteRecordThenReadTest) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string actual; + EXPECT_OK(rr->Read(&actual)); + EXPECT_EQ("data", actual); + EXPECT_OK(rr->Close()); +} + +TEST(FileTest, CannotOpenIfAlreadyOpened) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + EXPECT_FALSE(rr->Open(TempDir() + "test_file.txt").ok()); +} + +TEST(FileTest, OpensIfClosed) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rr->Close()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); +} + +TEST(FileTest, WriteMultipleRecordsThenReadTest) { + Context ctx; + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("the first record.")); + char written2_char[] = "raw\0record"; + std::string written2(written2_char, 10); + EXPECT_OK(rw->Write(written2)); + std::string num_bytes = ctx.CreateBigNum(1111111111).ToBytes(); + EXPECT_OK(rw->Write(num_bytes)); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string read; + EXPECT_TRUE(rr->HasMore().value()); + EXPECT_OK(rr->Read(&read)); + EXPECT_EQ("the first record.", read); + EXPECT_TRUE(rr->HasMore().value()); + std::string raw_read; + EXPECT_OK(rr->Read(&raw_read)); + EXPECT_EQ(written2, raw_read); + EXPECT_NE("raw", raw_read); + EXPECT_EQ(10, raw_read.size()); + EXPECT_TRUE(rr->HasMore().value()); + EXPECT_OK(rr->Read(&read)); + EXPECT_EQ(num_bytes, read); + EXPECT_FALSE(rr->HasMore().value()); + EXPECT_OK(rr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsInSortedOrder) { + std::vector<std::string> filenames({TempDir() + "test_file0", + TempDir() + "test_file1", + TempDir() + "test_file2"}); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + std::vector<std::string> records( + {std::string("1\00", 3), std::string("1\01", 3), std::string("1\02", 3), + std::string("1\03", 3), std::string("1\04", 3), std::string("1\05", 3)}); + EXPECT_OK(rw->Write(records[4])); + EXPECT_OK(rw->Write(records[5])); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write(records[2])); + EXPECT_OK(rw->Write(records[3])); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[2])); + EXPECT_OK(rw->Write(records[0])); + EXPECT_OK(rw->Write(records[1])); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[0], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[1], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[2], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[3], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[4], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[5], data); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_FALSE(msr->Open(filenames).ok()); + EXPECT_OK(msr->Close()); + EXPECT_OK(msr->Open(filenames)); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderSortsBasedOnProtoKeyField) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("1", "tiny"))); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("3", "ti"))); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("2", "tin"))); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("4", "t"))); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames, [](absl::string_view raw_data) { + return ProtoUtils::FromString<TestProto>(raw_data).record(); + })); + std::string data; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("1", "tiny"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("2", "tin"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("3", "ti"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("4", "t"), data); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsIndicesAsWell) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write("1")); + EXPECT_OK(rw->Write("3")); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + int index; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(1, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsDuplicateRecordsInOrderOfTheFileIndex) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write("1")); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + int index; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(1, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, LineReaderTest) { + std::ofstream ofs(TempDir() + "test_file.txt"); + ofs << "Line1\nLine2\n\n"; + ofs.close(); + auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(lr->Open(TempDir() + "test_file.txt")); + std::string line; + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line1", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line2", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("", line); + EXPECT_FALSE(lr->HasMore().value()); + EXPECT_OK(lr->Close()); +} + +TEST(FileTest, LineReaderTestWithoutNewline) { + std::ofstream ofs(TempDir() + "test_file.txt"); + ofs << "Line1\nLine2"; + ofs.close(); + auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(lr->Open(TempDir() + "test_file.txt")); + std::string line; + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line1", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line2", line); + EXPECT_FALSE(lr->HasMore().value()); + EXPECT_OK(lr->Close()); +} + +TEST(FileTest, LineWriterTest) { + auto rw = std::unique_ptr<LineWriter>(LineWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string actual; + EXPECT_OK(rr->Read(&actual)); + EXPECT_EQ("data", actual); + EXPECT_OK(rr->Close()); +} + +TEST(ShardingWriterTest, WritesInShards) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/1); + writer->SetShardPrefix(TempDir() + "test_file"); + + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_THAT(writer->Close(), + IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0", + TempDir() + "test_file1", + TempDir() + "test_file2"}))); + + ExpectFileContainsRecords(TempDir() + "test_file0", {"22"}); + ExpectFileContainsRecords(TempDir() + "test_file1", {"33"}); + ExpectFileContainsRecords(TempDir() + "test_file2", {"11"}); +} + +TEST(ShardingWriterTest, WritesInSortedShards) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + writer->SetShardPrefix(TempDir() + "test_file"); + + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_THAT(writer->Close(), + IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0"}))); + + ExpectFileContainsRecords(TempDir() + "test_file0", {"11", "22", "33"}); +} + +TEST(ShardingWriterTest, CreatesNoShardsWhenNoRecordsWritten) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/1); + writer->SetShardPrefix(TempDir() + "test_file"); + EXPECT_THAT(writer->Close(), IsOkAndHolds(IsEmpty())); +} + +TEST(ShardingWriterTest, FailsIfWriteBeforeSettingOutputFilenames) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + ExpectInternalErrorWithSubstring( + writer->Write(ProtoUtils::ToString(GetRecord("22"))), + "Must call SetShardPrefix before calling Write."); +} + +TEST(ShardingWriterTest, FailsIfCloseBeforeSettingOutputFilenames) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + ExpectInternalErrorWithSubstring( + writer->Close().status(), + "Must call SetShardPrefix before calling Close."); +} + +TEST(ShardingMergerTest, MergesMultipleFilesCorrectly) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + EXPECT_OK(writer->Open(TempDir() + "test_file0")); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66")))); + EXPECT_OK(writer->Close()); + EXPECT_OK(writer->Open(TempDir() + "test_file1")); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("77")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("99")))); + EXPECT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {TempDir() + "test_file0", TempDir() + "test_file1"}, + TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + std::string record; + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("11", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("77", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("99", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, MergesSingleFileCorrectly) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + ASSERT_OK(writer->Open(TempDir() + "test_file0")); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44")))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66")))); + ASSERT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {TempDir() + "test_file0"}, TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + std::string record; + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, CreatesEmptyFileIfNoShardsProvided) { + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {} /* no shard files */, TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, DeletesFiles) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + ASSERT_OK(writer->Open(TempDir() + "test_file0")); + ASSERT_OK(writer->Close()); + ASSERT_OK(writer->Open(TempDir() + "test_file1")); + ASSERT_OK(writer->Close()); + ASSERT_OK(writer->Open(TempDir() + "test_file2")); + ASSERT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Delete({TempDir() + "test_file0", TempDir() + "test_file1", + TempDir() + "test_file2"})); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file0").ok()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file1").ok()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file2").ok()); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/util/status.inc b/util/status.inc index aa84e4b..e645841 100644 --- a/util/status.inc +++ b/util/status.inc @@ -21,7 +21,7 @@ namespace private_join_and_compute { // Aliases StatusCode to be compatible with our code. using StatusCode = ::absl::StatusCode; -// Aliases Status, StatusOr and canonical errors. This alias exists for +// Aliases Status, StatusOr and canonical errors. This alias exists for // historical reasons (when this library had a fork of absl::Status). using Status = absl::Status; template <typename T> diff --git a/util/status_macros.h b/util/status_macros.h index b760085..005239c 100644 --- a/util/status_macros.h +++ b/util/status_macros.h @@ -37,6 +37,21 @@ } \ lhs = std::move(statusor).value() +// Helper macro that checks if the given expression evaluates to a +// Status with Status OK. If not, returns the error status. Example: +// RETURN_IF_ERROR(expression); +#define RETURN_IF_ERROR(expr) \ + PRIVACY_BLINDERS_RETURN_IF_ERROR_IMPL_( \ + PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_(status_value, __LINE__), \ + expr) + +// Internal helper. +#define PRIVACY_BLINDERS_RETURN_IF_ERROR_IMPL_(status, expr) \ + auto status = (expr); \ + if (ABSL_PREDICT_FALSE(!status.ok())) { \ + return status; \ + } + // Internal helper for concatenating macro values. #define PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y #define PRIVACY_BLINDERS_STATUS_MACROS_IMPL_CONCAT_(x, y) \ diff --git a/util/status_matchers.h b/util/status_matchers.h new file mode 100644 index 0000000..1ec6890 --- /dev/null +++ b/util/status_matchers.h @@ -0,0 +1,258 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_STATUS_MATCHERS_H_ +#define UTIL_STATUS_MATCHERS_H_ + +#include "util/status.inc" +#include <gmock/gmock.h> + +namespace private_join_and_compute { +namespace testing { + +#ifdef GTEST_HAS_STATUS_MATCHERS + +using ::testing::status::IsOk; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; + +#else // GTEST_HAS_STATUS_MATCHERS + +// copybara:strip_begin(Remove sensitive comments) +// +// This is partially copied from the code in +// google3/testing/base/public/gmock_utils/status-matchers.h +// but is somewhat simplified to only include the things we need for the PJC +// library. +// +// copybara:strip_end + +namespace internal { + +// This function and its overload allow the same matcher to be used for Status +// and StatusOr tests. +inline Status GetStatus(const Status& status) { return status; } + +template <typename T> +inline Status GetStatus(const StatusOr<T>& statusor) { + return statusor.status(); +} + +template <typename StatusType> +class StatusIsImpl : public ::testing::MatcherInterface<StatusType> { + public: + StatusIsImpl(const ::testing::Matcher<StatusCode>& code, + const ::testing::Matcher<const std::string&>& message) + : code_(code), message_(message) {} + + bool MatchAndExplain( + StatusType status, + ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener str_listener; + Status real_status = GetStatus(status); + if (!code_.MatchAndExplain(real_status.code(), &str_listener)) { + *listener << str_listener.str(); + return false; + } + if (!message_.MatchAndExplain( + static_cast<std::string>(real_status.message()), &str_listener)) { + *listener << str_listener.str(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "has a status code that "; + code_.DescribeTo(os); + *os << " and a message that "; + message_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "has a status code that "; + code_.DescribeNegationTo(os); + *os << " and a message that "; + message_.DescribeNegationTo(os); + } + + private: + ::testing::Matcher<StatusCode> code_; + ::testing::Matcher<const std::string&> message_; +}; + +class StatusIsPoly { + public: + StatusIsPoly(::testing::Matcher<StatusCode>&& code, + ::testing::Matcher<const std::string&>&& message) + : code_(code), message_(message) {} + + // Converts this polymorphic matcher to a monomorphic matcher. + template <typename StatusType> + operator ::testing::Matcher<StatusType>() const { + return ::testing::Matcher<StatusType>( + new StatusIsImpl<StatusType>(code_, message_)); + } + + private: + ::testing::Matcher<StatusCode> code_; + ::testing::Matcher<const std::string&> message_; +}; + +} // namespace internal + +// This function allows us to avoid a template parameter when writing tests, so +// that we can transparently test both Status and StatusOr returns. +inline internal::StatusIsPoly StatusIs( + ::testing::Matcher<StatusCode>&& code, + ::testing::Matcher<const std::string&>&& message) { + return internal::StatusIsPoly( + std::forward<::testing::Matcher<StatusCode>>(code), + std::forward<::testing::Matcher<const std::string&>>(message)); +} + +// copybara:strip_begin(Remove sensitive comments) +// +// This is partially copied from the code in +// third_party/absl/status/statusor_test.cc +// +// copybara:strip_end + +// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a +// reference to StatusOr<T>. +template <typename StatusOrType> +class IsOkAndHoldsMatcherImpl + : public ::testing::MatcherInterface<StatusOrType> { + public: + typedef + typename std::remove_reference<StatusOrType>::type::value_type value_type; + + template <typename InnerMatcher> + explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) + : inner_matcher_(::testing::SafeMatcherCast<const value_type&>( + std::forward<InnerMatcher>(inner_matcher))) {} + + void DescribeTo(std::ostream* os) const override { + *os << "is OK and has a value that "; + inner_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "isn't OK or has a value that "; + inner_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain( + StatusOrType actual_value, + ::testing::MatchResultListener* result_listener) const override { + if (!actual_value.ok()) { + *result_listener << "which has status " << actual_value.status(); + return false; + } + + ::testing::StringMatchResultListener inner_listener; + const bool matches = + inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); + const std::string inner_explanation = inner_listener.str(); + if (!inner_explanation.empty()) { + *result_listener << "which contains value " + << ::testing::PrintToString(*actual_value) << ", " + << inner_explanation; + } + return matches; + } + + private: + const ::testing::Matcher<const value_type&> inner_matcher_; +}; + +// Implements IsOkAndHolds(m) as a polymorphic matcher. +template <typename InnerMatcher> +class IsOkAndHoldsMatcher { + public: + explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) + : inner_matcher_(std::move(inner_matcher)) {} + + // Converts this polymorphic matcher to a monomorphic matcher of the + // given type. StatusOrType can be either StatusOr<T> or a + // reference to StatusOr<T>. + template <typename StatusOrType> + operator ::testing::Matcher<StatusOrType>() const { // NOLINT + return ::testing::Matcher<StatusOrType>( + new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_)); + } + + private: + const InnerMatcher inner_matcher_; +}; + +// Monomorphic implementation of matcher IsOk() for a given type T. +// T can be Status, StatusOr<>, or a reference to either of them. +template <typename T> +class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> { + public: + void DescribeTo(std::ostream* os) const override { *os << "is OK"; } + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not OK"; + } + bool MatchAndExplain(T actual_value, + ::testing::MatchResultListener*) const override { + return GetStatus(actual_value).ok(); + } +}; + +// Implements IsOk() as a polymorphic matcher. +class IsOkMatcher { + public: + template <typename T> + operator ::testing::Matcher<T>() const { // NOLINT + return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>()); + } +}; + +// Returns a gMock matcher that matches a StatusOr<> whose status is +// OK and whose value matches the inner matcher. +template <typename InnerMatcher> +IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> IsOkAndHolds( + InnerMatcher&& inner_matcher) { + return IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>( + std::forward<InnerMatcher>(inner_matcher)); +} + +// Returns a gMock matcher that matches a Status or StatusOr<> which is OK. +inline IsOkMatcher IsOk() { return IsOkMatcher(); } + +#endif // GTEST_HAS_STATUS_MATCHERS + +} // namespace testing +} // namespace private_join_and_compute + +#endif // UTIL_STATUS_MATCHERS_H_ diff --git a/util/status_testing.h b/util/status_testing.h new file mode 100644 index 0000000..ab269fa --- /dev/null +++ b/util/status_testing.h @@ -0,0 +1,74 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UTIL_STATUS_TESTING_H_ +#define UTIL_STATUS_TESTING_H_ + +#include "util/status.inc" +#include <gmock/gmock.h> + +#ifndef GTEST_HAS_STATUS_MATCHERS + +#define ASSERT_OK(expr) \ + PRIVACY_BLINDERS_ASSERT_OK_IMPL_( \ + PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), expr) + +#define PRIVACY_BLINDERS_ASSERT_OK_IMPL_(status, expr) \ + auto status = (expr); \ + ASSERT_THAT(status.ok(), ::testing::Eq(true)); + +#define EXPECT_OK(expr) \ + PRIVACY_BLINDERS_EXPECT_OK_IMPL_( \ + PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), expr) + +#define PRIVACY_BLINDERS_EXPECT_OK_IMPL_(status, expr) \ + auto status = (expr); \ + EXPECT_THAT(status.ok(), ::testing::Eq(true)); + +#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + PRIVACY_BLINDERS_ASSERT_OK_AND_ASSIGN_IMPL_( \ + PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(_status_or_value, \ + __LINE__), \ + lhs, rexpr) + +#define PRIVACY_BLINDERS_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_THAT(statusor.ok(), ::testing::Eq(true)); \ + lhs = std::move(statusor).value() + +// Internal helper for concatenating macro values. +#define PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) x##y +#define PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_(x, y) \ + PRIVACY_BLINDERS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) + +#endif // GTEST_HAS_STATUS_MATCHERS + +#endif // UTIL_STATUS_TESTING_H_ diff --git a/util/status_testing.inc b/util/status_testing.inc new file mode 100644 index 0000000..1f07284 --- /dev/null +++ b/util/status_testing.inc @@ -0,0 +1,17 @@ +/* + * Copyright 2019 Google Inc. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "util/status_matchers.h" +#include "util/status_testing.h" |