Commit bd253f01 authored by Joshua Haberman's avatar Joshua Haberman

Fixed equality, and extended to repeated fields and maps.

parent 9cfb12bf
...@@ -1444,7 +1444,7 @@ static void putmsg(VALUE msg_rb, const Descriptor* desc, ...@@ -1444,7 +1444,7 @@ static void putmsg(VALUE msg_rb, const Descriptor* desc,
int type = TYPE(val); int type = TYPE(val);
if (type != T_DATA && type != T_NIL && is_wrapper_type_field(f)) { if (type != T_DATA && type != T_NIL && is_wrapper_type_field(f)) {
// OPT: could try to avoid expanding the wrapper here. // OPT: could try to avoid expanding the wrapper here.
val = ruby_wrapper_type(desc->layout, f, val); val = ruby_wrapper_type(field_type_class(desc->layout, f), val);
DEREF(msg, offset, VALUE) = val; DEREF(msg, offset, VALUE) = val;
} }
putsubmsg(val, f, sink, depth, emit_defaults, is_json); putsubmsg(val, f, sink, depth, emit_defaults, is_json);
......
...@@ -631,7 +631,8 @@ VALUE Map_eq(VALUE _self, VALUE _other) { ...@@ -631,7 +631,8 @@ VALUE Map_eq(VALUE _self, VALUE _other) {
return Qfalse; return Qfalse;
} }
if (!native_slot_eq(self->value_type, mem, other_mem)) { if (!native_slot_eq(self->value_type, self->value_type_class, mem,
other_mem)) {
// Present, but value not equal. // Present, but value not equal.
return Qfalse; return Qfalse;
} }
......
...@@ -131,14 +131,13 @@ bool is_wrapper_type_field(const upb_fielddef* field) { ...@@ -131,14 +131,13 @@ bool is_wrapper_type_field(const upb_fielddef* field) {
} }
// Get a new Ruby wrapper type and set the initial value // Get a new Ruby wrapper type and set the initial value
VALUE ruby_wrapper_type(const MessageLayout* layout, const upb_fielddef* field, VALUE ruby_wrapper_type(VALUE type_class, VALUE value) {
const VALUE value) { if (value != Qnil) {
if (is_wrapper_type_field(field) && value != Qnil) {
VALUE hash = rb_hash_new(); VALUE hash = rb_hash_new();
rb_hash_aset(hash, rb_str_new2("value"), value); rb_hash_aset(hash, rb_str_new2("value"), value);
{ {
VALUE args[1] = {hash}; VALUE args[1] = {hash};
return rb_class_new_instance(1, args, field_type_class(layout, field)); return rb_class_new_instance(1, args, type_class);
} }
} }
return Qnil; return Qnil;
...@@ -343,7 +342,8 @@ VALUE Message_method_missing(int argc, VALUE* argv, VALUE _self) { ...@@ -343,7 +342,8 @@ VALUE Message_method_missing(int argc, VALUE* argv, VALUE _self) {
return value; return value;
} }
} else if (accessor_type == METHOD_WRAPPER_SETTER) { } else if (accessor_type == METHOD_WRAPPER_SETTER) {
VALUE wrapper = ruby_wrapper_type(self->descriptor->layout, f, argv[1]); VALUE wrapper = ruby_wrapper_type(
field_type_class(self->descriptor->layout, f), argv[1]);
layout_set(self->descriptor->layout, Message_data(self), f, wrapper); layout_set(self->descriptor->layout, Message_data(self), f, wrapper);
return Qnil; return Qnil;
} else if (accessor_type == METHOD_ENUM_GETTER) { } else if (accessor_type == METHOD_ENUM_GETTER) {
......
...@@ -364,7 +364,8 @@ void native_slot_init(upb_fieldtype_t type, void* memory); ...@@ -364,7 +364,8 @@ void native_slot_init(upb_fieldtype_t type, void* memory);
void native_slot_mark(upb_fieldtype_t type, void* memory); void native_slot_mark(upb_fieldtype_t type, void* memory);
void native_slot_dup(upb_fieldtype_t type, void* to, void* from); void native_slot_dup(upb_fieldtype_t type, void* to, void* from);
void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from); void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from);
bool native_slot_eq(upb_fieldtype_t type, void* mem1, void* mem2); bool native_slot_eq(upb_fieldtype_t type, VALUE type_class, void* mem1,
void* mem2);
VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value); VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value);
void native_slot_check_int_range_precision(const char* name, upb_fieldtype_t type, VALUE value); void native_slot_check_int_range_precision(const char* name, upb_fieldtype_t type, VALUE value);
...@@ -557,8 +558,7 @@ VALUE layout_hash(MessageLayout* layout, void* storage); ...@@ -557,8 +558,7 @@ VALUE layout_hash(MessageLayout* layout, void* storage);
VALUE layout_inspect(MessageLayout* layout, void* storage); VALUE layout_inspect(MessageLayout* layout, void* storage);
bool is_wrapper_type_field(const upb_fielddef* field); bool is_wrapper_type_field(const upb_fielddef* field);
VALUE ruby_wrapper_type(const MessageLayout* layout, const upb_fielddef* field, VALUE ruby_wrapper_type(VALUE type_class, VALUE value);
const VALUE value);
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Message class creation. // Message class creation.
......
...@@ -451,7 +451,8 @@ VALUE RepeatedField_eq(VALUE _self, VALUE _other) { ...@@ -451,7 +451,8 @@ VALUE RepeatedField_eq(VALUE _self, VALUE _other) {
for (i = 0; i < self->size; i++, off += elem_size) { for (i = 0; i < self->size; i++, off += elem_size) {
void* self_mem = ((uint8_t *)self->elements) + off; void* self_mem = ((uint8_t *)self->elements) + off;
void* other_mem = ((uint8_t *)other->elements) + off; void* other_mem = ((uint8_t *)other->elements) + off;
if (!native_slot_eq(field_type, self_mem, other_mem)) { if (!native_slot_eq(field_type, self->field_type_class, self_mem,
other_mem)) {
return Qfalse; return Qfalse;
} }
} }
......
...@@ -294,8 +294,17 @@ VALUE native_slot_get(upb_fieldtype_t type, ...@@ -294,8 +294,17 @@ VALUE native_slot_get(upb_fieldtype_t type,
return DEREF(memory, int8_t) ? Qtrue : Qfalse; return DEREF(memory, int8_t) ? Qtrue : Qfalse;
case UPB_TYPE_STRING: case UPB_TYPE_STRING:
case UPB_TYPE_BYTES: case UPB_TYPE_BYTES:
case UPB_TYPE_MESSAGE:
return DEREF(memory, VALUE); return DEREF(memory, VALUE);
case UPB_TYPE_MESSAGE: {
VALUE val = DEREF(memory, VALUE);
int type = TYPE(val);
if (type != T_DATA && type != T_NIL) {
// This must be a wrapper type.
val = ruby_wrapper_type(type_class, val);
DEREF(memory, VALUE) = val;
}
return val;
}
case UPB_TYPE_ENUM: { case UPB_TYPE_ENUM: {
int32_t val = DEREF(memory, int32_t); int32_t val = DEREF(memory, int32_t);
VALUE symbol = enum_lookup(type_class, INT2NUM(val)); VALUE symbol = enum_lookup(type_class, INT2NUM(val));
...@@ -392,13 +401,14 @@ void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from) { ...@@ -392,13 +401,14 @@ void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from) {
} }
} }
bool native_slot_eq(upb_fieldtype_t type, void* mem1, void* mem2) { bool native_slot_eq(upb_fieldtype_t type, VALUE type_class, void* mem1,
void* mem2) {
switch (type) { switch (type) {
case UPB_TYPE_STRING: case UPB_TYPE_STRING:
case UPB_TYPE_BYTES: case UPB_TYPE_BYTES:
case UPB_TYPE_MESSAGE: { case UPB_TYPE_MESSAGE: {
VALUE val1 = DEREF(mem1, VALUE); VALUE val1 = native_slot_get(type, type_class, mem1);
VALUE val2 = DEREF(mem2, VALUE); VALUE val2 = native_slot_get(type, type_class, mem2);
VALUE ret = rb_funcall(val1, rb_intern("=="), 1, val2); VALUE ret = rb_funcall(val1, rb_intern("=="), 1, val2);
return ret == Qtrue; return ret == Qtrue;
} }
...@@ -843,15 +853,8 @@ VALUE layout_get(MessageLayout* layout, ...@@ -843,15 +853,8 @@ VALUE layout_get(MessageLayout* layout,
} else if (!field_set) { } else if (!field_set) {
return layout_get_default(field); return layout_get_default(field);
} else { } else {
VALUE type_class = field_type_class(layout, field); return native_slot_get(upb_fielddef_type(field),
VALUE val = native_slot_get(upb_fielddef_type(field), type_class, memory); field_type_class(layout, field), memory);
int type = TYPE(val);
if (type != T_DATA && type != T_NIL && is_wrapper_type_field(field)) {
val = ruby_wrapper_type(layout, field, val);
native_slot_set(upb_fielddef_name(field), upb_fielddef_type(field),
type_class, memory, val);
}
return val;
} }
} }
...@@ -1068,7 +1071,8 @@ VALUE layout_eq(MessageLayout* layout, void* msg1, void* msg2) { ...@@ -1068,7 +1071,8 @@ VALUE layout_eq(MessageLayout* layout, void* msg1, void* msg2) {
if (*msg1_oneof_case != *msg2_oneof_case || if (*msg1_oneof_case != *msg2_oneof_case ||
(slot_read_oneof_case(layout, msg1, oneof) == (slot_read_oneof_case(layout, msg1, oneof) ==
upb_fielddef_number(field) && upb_fielddef_number(field) &&
!native_slot_eq(upb_fielddef_type(field), msg1_memory, !native_slot_eq(upb_fielddef_type(field),
field_type_class(layout, field), msg1_memory,
msg2_memory))) { msg2_memory))) {
return Qfalse; return Qfalse;
} }
...@@ -1085,7 +1089,9 @@ VALUE layout_eq(MessageLayout* layout, void* msg1, void* msg2) { ...@@ -1085,7 +1089,9 @@ VALUE layout_eq(MessageLayout* layout, void* msg1, void* msg2) {
} else { } else {
if (slot_is_hasbit_set(layout, msg1, field) != if (slot_is_hasbit_set(layout, msg1, field) !=
slot_is_hasbit_set(layout, msg2, field) || slot_is_hasbit_set(layout, msg2, field) ||
!native_slot_eq(upb_fielddef_type(field), msg1_memory, msg2_memory)) { !native_slot_eq(upb_fielddef_type(field),
field_type_class(layout, field), msg1_memory,
msg2_memory)) {
return Qfalse; return Qfalse;
} }
} }
......
...@@ -128,6 +128,18 @@ message Wrapper { ...@@ -128,6 +128,18 @@ message Wrapper {
oneof a_oneof { oneof a_oneof {
string oneof_string = 10; string oneof_string = 10;
} }
// Repeated wrappers don't really make sense, but we still need to make sure
// they work and don't crash.
repeated google.protobuf.DoubleValue repeated_double = 11;
repeated google.protobuf.FloatValue repeated_float = 12;
repeated google.protobuf.Int32Value repeated_int32 = 13;
repeated google.protobuf.Int64Value repeated_int64 = 14;
repeated google.protobuf.UInt32Value repeated_uint32 = 15;
repeated google.protobuf.UInt64Value repeated_uint64 = 16;
repeated google.protobuf.BoolValue repeated_bool = 17;
repeated google.protobuf.StringValue repeated_string = 18;
repeated google.protobuf.BytesValue repeated_bytes = 19;
} }
message TimeMessage { message TimeMessage {
......
...@@ -135,6 +135,18 @@ message Wrapper { ...@@ -135,6 +135,18 @@ message Wrapper {
oneof a_oneof { oneof a_oneof {
string oneof_string = 10; string oneof_string = 10;
} }
// Repeated wrappers don't really make sense, but we still need to make sure
// they work and don't crash.
repeated google.protobuf.DoubleValue repeated_double = 11;
repeated google.protobuf.FloatValue repeated_float = 12;
repeated google.protobuf.Int32Value repeated_int32 = 13;
repeated google.protobuf.Int64Value repeated_int64 = 14;
repeated google.protobuf.UInt32Value repeated_uint32 = 15;
repeated google.protobuf.UInt64Value repeated_uint64 = 16;
repeated google.protobuf.BoolValue repeated_bool = 17;
repeated google.protobuf.StringValue repeated_string = 18;
repeated google.protobuf.BytesValue repeated_bytes = 19;
} }
message TimeMessage { message TimeMessage {
......
...@@ -1331,9 +1331,7 @@ module CommonTests ...@@ -1331,9 +1331,7 @@ module CommonTests
# Test that the lazy form compares equal to the expanded form. # Test that the lazy form compares equal to the expanded form.
m5 = proto_module::Wrapper::decode(serialized2) m5 = proto_module::Wrapper::decode(serialized2)
assert_equal m5, m
# This doesn't work yet.
# assert_equal m5, m
end end
def test_wrapper_setters_as_value def test_wrapper_setters_as_value
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment