diff options
author | Joe Tsai <joetsai@digital-static.net> | 2022-11-15 10:26:06 -0800 |
---|---|---|
committer | Christian Höppner <hoeppi@google.com> | 2022-12-01 15:36:18 +0000 |
commit | b2a7dfe48b6ae3aac5f05941373d4478c104aa6b (patch) | |
tree | 7276cf7d320fd86abba4fa3f530e4e5dccd00fe7 | |
parent | f0e23c7a8f8d55910f877091fdad71d3abb26d10 (diff) | |
download | golang-protobuf-b2a7dfe48b6ae3aac5f05941373d4478c104aa6b.tar.gz |
reflect/protoreflect: add Value.Equal method
The Value.Equal method compares Values in a way that is
analogous to the reflect.DeepEqual function.
Most of the implementation is moved from the "proto" package,
with proto.Equal implemented in terms of protoreflect.Value.Equal.
Change-Id: Ie9d5f6c073dc49fa59f78385c441db8137de5ec5
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/450775
TryBot-Bypass: Heschi Kreinick <heschi@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Oleksii Prokopchuk <prokopchuk@google.com>
Reviewed-by: Christian Höppner <hoeppi@google.com>
Run-TryBot: Lasse Folger <lassefolger@google.com>
-rw-r--r-- | proto/equal.go | 167 | ||||
-rw-r--r-- | reflect/protoreflect/value_equal.go | 168 | ||||
-rw-r--r-- | reflect/protoreflect/value_test.go | 50 |
3 files changed, 239 insertions, 146 deletions
diff --git a/proto/equal.go b/proto/equal.go index 67948dd1..baa1bc19 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -5,30 +5,34 @@ package proto import ( - "bytes" - "math" "reflect" - "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" ) -// Equal reports whether two messages are equal. -// If two messages marshal to the same bytes under deterministic serialization, -// then Equal is guaranteed to report true. +// Equal reports whether two messages are equal, +// by recursively comparing the fields of the message. +// +// - Bytes fields are equal if they contain identical bytes. +// Empty bytes (regardless of nil-ness) are considered equal. +// +// - Floating-point fields are equal if they contain the same value. +// Unlike the == operator, a NaN is equal to another NaN. +// +// - Other scalar fields are equal if they contain the same value. +// +// - Message fields are equal if they have +// the same set of populated known and extension field values, and +// the same set of unknown fields values. // -// Two messages are equal if they belong to the same message descriptor, -// have the same set of populated known and extension field values, -// and the same set of unknown fields values. If either of the top-level -// messages are invalid, then Equal reports true only if both are invalid. +// - Lists are equal if they are the same length and +// each corresponding element is equal. // -// Scalar values are compared with the equivalent of the == operator in Go, -// except bytes values which are compared using bytes.Equal and -// floating point values which specially treat NaNs as equal. -// Message values are compared by recursively calling Equal. -// Lists are equal if each element value is also equal. -// Maps are equal if they have the same set of keys, where the pair of values -// for each key is also equal. +// - Maps are equal if they have the same set of keys and +// the corresponding value for each key is equal. +// +// If two messages marshal to the same bytes under deterministic serialization, +// then Equal is guaranteed to report true. func Equal(x, y Message) bool { if x == nil || y == nil { return x == nil && y == nil @@ -42,130 +46,7 @@ func Equal(x, y Message) bool { if mx.IsValid() != my.IsValid() { return false } - return equalMessage(mx, my) -} - -// equalMessage compares two messages. -func equalMessage(mx, my protoreflect.Message) bool { - if mx.Descriptor() != my.Descriptor() { - return false - } - - nx := 0 - equal := true - mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { - nx++ - vy := my.Get(fd) - equal = my.Has(fd) && equalField(fd, vx, vy) - return equal - }) - if !equal { - return false - } - ny := 0 - my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { - ny++ - return true - }) - if nx != ny { - return false - } - - return equalUnknown(mx.GetUnknown(), my.GetUnknown()) -} - -// equalField compares two fields. -func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { - switch { - case fd.IsList(): - return equalList(fd, x.List(), y.List()) - case fd.IsMap(): - return equalMap(fd, x.Map(), y.Map()) - default: - return equalValue(fd, x, y) - } -} - -// equalMap compares two maps. -func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool { - if x.Len() != y.Len() { - return false - } - equal := true - x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool { - vy := y.Get(k) - equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) - return equal - }) - return equal -} - -// equalList compares two lists. -func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool { - if x.Len() != y.Len() { - return false - } - for i := x.Len() - 1; i >= 0; i-- { - if !equalValue(fd, x.Get(i), y.Get(i)) { - return false - } - } - return true -} - -// equalValue compares two singular values. -func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { - switch fd.Kind() { - case protoreflect.BoolKind: - return x.Bool() == y.Bool() - case protoreflect.EnumKind: - return x.Enum() == y.Enum() - case protoreflect.Int32Kind, protoreflect.Sint32Kind, - protoreflect.Int64Kind, protoreflect.Sint64Kind, - protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind: - return x.Int() == y.Int() - case protoreflect.Uint32Kind, protoreflect.Uint64Kind, - protoreflect.Fixed32Kind, protoreflect.Fixed64Kind: - return x.Uint() == y.Uint() - case protoreflect.FloatKind, protoreflect.DoubleKind: - fx := x.Float() - fy := y.Float() - if math.IsNaN(fx) || math.IsNaN(fy) { - return math.IsNaN(fx) && math.IsNaN(fy) - } - return fx == fy - case protoreflect.StringKind: - return x.String() == y.String() - case protoreflect.BytesKind: - return bytes.Equal(x.Bytes(), y.Bytes()) - case protoreflect.MessageKind, protoreflect.GroupKind: - return equalMessage(x.Message(), y.Message()) - default: - return x.Interface() == y.Interface() - } -} - -// equalUnknown compares unknown fields by direct comparison on the raw bytes -// of each individual field number. -func equalUnknown(x, y protoreflect.RawFields) bool { - if len(x) != len(y) { - return false - } - if bytes.Equal([]byte(x), []byte(y)) { - return true - } - - mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields) - my := make(map[protoreflect.FieldNumber]protoreflect.RawFields) - for len(x) > 0 { - fnum, _, n := protowire.ConsumeField(x) - mx[fnum] = append(mx[fnum], x[:n]...) - x = x[n:] - } - for len(y) > 0 { - fnum, _, n := protowire.ConsumeField(y) - my[fnum] = append(my[fnum], y[:n]...) - y = y[n:] - } - return reflect.DeepEqual(mx, my) + vx := protoreflect.ValueOfMessage(mx) + vy := protoreflect.ValueOfMessage(my) + return vx.Equal(vy) } diff --git a/reflect/protoreflect/value_equal.go b/reflect/protoreflect/value_equal.go new file mode 100644 index 00000000..59165254 --- /dev/null +++ b/reflect/protoreflect/value_equal.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protoreflect + +import ( + "bytes" + "fmt" + "math" + "reflect" + + "google.golang.org/protobuf/encoding/protowire" +) + +// Equal reports whether v1 and v2 are recursively equal. +// +// - Values of different types are always unequal. +// +// - Bytes values are equal if they contain identical bytes. +// Empty bytes (regardless of nil-ness) are considered equal. +// +// - Floating point values are equal if they contain the same value. +// Unlike the == operator, a NaN is equal to another NaN. +// +// - Enums are equal if they contain the same number. +// Since Value does not contain an enum descriptor, +// enum values do not consider the type of the enum. +// +// - Other scalar values are equal if they contain the same value. +// +// - Message values are equal if they belong to the same message descriptor, +// have the same set of populated known and extension field values, +// and the same set of unknown fields values. +// +// - Lists are equal if they are the same length and +// each corresponding element is equal. +// +// - Maps are equal if they have the same set of keys and +// the corresponding value for each key is equal. +func (v1 Value) Equal(v2 Value) bool { + return equalValue(v1, v2) +} + +func equalValue(x, y Value) bool { + eqType := x.typ == y.typ + switch x.typ { + case nilType: + return eqType + case boolType: + return eqType && x.Bool() == y.Bool() + case int32Type, int64Type: + return eqType && x.Int() == y.Int() + case uint32Type, uint64Type: + return eqType && x.Uint() == y.Uint() + case float32Type, float64Type: + return eqType && equalFloat(x.Float(), y.Float()) + case stringType: + return eqType && x.String() == y.String() + case bytesType: + return eqType && bytes.Equal(x.Bytes(), y.Bytes()) + case enumType: + return eqType && x.Enum() == y.Enum() + default: + switch x := x.Interface().(type) { + case Message: + y, ok := y.Interface().(Message) + return ok && equalMessage(x, y) + case List: + y, ok := y.Interface().(List) + return ok && equalList(x, y) + case Map: + y, ok := y.Interface().(Map) + return ok && equalMap(x, y) + default: + panic(fmt.Sprintf("unknown type: %T", x)) + } + } +} + +// equalFloat compares two floats, where NaNs are treated as equal. +func equalFloat(x, y float64) bool { + if math.IsNaN(x) || math.IsNaN(y) { + return math.IsNaN(x) && math.IsNaN(y) + } + return x == y +} + +// equalMessage compares two messages. +func equalMessage(mx, my Message) bool { + if mx.Descriptor() != my.Descriptor() { + return false + } + + nx := 0 + equal := true + mx.Range(func(fd FieldDescriptor, vx Value) bool { + nx++ + vy := my.Get(fd) + equal = my.Has(fd) && equalValue(vx, vy) + return equal + }) + if !equal { + return false + } + ny := 0 + my.Range(func(fd FieldDescriptor, vx Value) bool { + ny++ + return true + }) + if nx != ny { + return false + } + + return equalUnknown(mx.GetUnknown(), my.GetUnknown()) +} + +// equalList compares two lists. +func equalList(x, y List) bool { + if x.Len() != y.Len() { + return false + } + for i := x.Len() - 1; i >= 0; i-- { + if !equalValue(x.Get(i), y.Get(i)) { + return false + } + } + return true +} + +// equalMap compares two maps. +func equalMap(x, y Map) bool { + if x.Len() != y.Len() { + return false + } + equal := true + x.Range(func(k MapKey, vx Value) bool { + vy := y.Get(k) + equal = y.Has(k) && equalValue(vx, vy) + return equal + }) + return equal +} + +// equalUnknown compares unknown fields by direct comparison on the raw bytes +// of each individual field number. +func equalUnknown(x, y RawFields) bool { + if len(x) != len(y) { + return false + } + if bytes.Equal([]byte(x), []byte(y)) { + return true + } + + mx := make(map[FieldNumber]RawFields) + my := make(map[FieldNumber]RawFields) + for len(x) > 0 { + fnum, _, n := protowire.ConsumeField(x) + mx[fnum] = append(mx[fnum], x[:n]...) + x = x[n:] + } + for len(y) > 0 { + fnum, _, n := protowire.ConsumeField(y) + my[fnum] = append(my[fnum], y[:n]...) + y = y[n:] + } + return reflect.DeepEqual(mx, my) +} diff --git a/reflect/protoreflect/value_test.go b/reflect/protoreflect/value_test.go index c6bbb584..d3879001 100644 --- a/reflect/protoreflect/value_test.go +++ b/reflect/protoreflect/value_test.go @@ -11,10 +11,13 @@ import ( "testing" ) +var ( + fakeMessage = new(struct{ Message }) + fakeList = new(struct{ List }) + fakeMap = new(struct{ Map }) +) + func TestValue(t *testing.T) { - fakeMessage := new(struct{ Message }) - fakeList := new(struct{ List }) - fakeMap := new(struct{ Map }) tests := []struct { in Value @@ -98,6 +101,47 @@ func TestValue(t *testing.T) { } } +func TestValueEqual(t *testing.T) { + tests := []struct { + x, y Value + want bool + }{ + {Value{}, Value{}, true}, + {Value{}, ValueOfBool(true), false}, + {ValueOfBool(true), ValueOfBool(true), true}, + {ValueOfBool(true), ValueOfBool(false), false}, + {ValueOfBool(false), ValueOfInt32(0), false}, + {ValueOfInt32(0), ValueOfInt32(0), true}, + {ValueOfInt32(0), ValueOfInt32(1), false}, + {ValueOfInt32(0), ValueOfInt64(0), false}, + {ValueOfInt64(123), ValueOfInt64(123), true}, + {ValueOfFloat64(0), ValueOfFloat64(0), true}, + {ValueOfFloat64(math.NaN()), ValueOfFloat64(math.NaN()), true}, + {ValueOfFloat64(math.NaN()), ValueOfFloat64(0), false}, + {ValueOfFloat64(math.Inf(1)), ValueOfFloat64(math.Inf(1)), true}, + {ValueOfFloat64(math.Inf(-1)), ValueOfFloat64(math.Inf(1)), false}, + {ValueOfBytes(nil), ValueOfBytes(nil), true}, + {ValueOfBytes(nil), ValueOfBytes([]byte{}), true}, + {ValueOfBytes(nil), ValueOfBytes([]byte{1}), false}, + {ValueOfEnum(0), ValueOfEnum(0), true}, + {ValueOfEnum(0), ValueOfEnum(1), false}, + {ValueOfBool(false), ValueOfMessage(fakeMessage), false}, + {ValueOfMessage(fakeMessage), ValueOfList(fakeList), false}, + {ValueOfList(fakeList), ValueOfMap(fakeMap), false}, + {ValueOfMap(fakeMap), ValueOfMessage(fakeMessage), false}, + + // Composite types are not tested here. + // See proto.TestEqual. + } + + for _, tt := range tests { + got := tt.x.Equal(tt.y) + if got != tt.want { + t.Errorf("(%v).Equal(%v) = %v, want %v", tt.x, tt.y, got, tt.want) + } + } +} + func BenchmarkValue(b *testing.B) { const testdata = "The quick brown fox jumped over the lazy dog." var sink1 string |