Commit 44bd6bda authored by Joshua Haberman's avatar Joshua Haberman Committed by GitHub

Merge pull request #1821 from haberman/rubyfreezestr

Ruby: encode and freeze strings when the are assigned or decoded.
parents 06220303 d07a9963
...@@ -54,7 +54,7 @@ VALUE noleak_rb_str_cat(VALUE rb_str, const char *str, long len) { ...@@ -54,7 +54,7 @@ VALUE noleak_rb_str_cat(VALUE rb_str, const char *str, long len) {
static const void* newhandlerdata(upb_handlers* h, uint32_t ofs) { static const void* newhandlerdata(upb_handlers* h, uint32_t ofs) {
size_t* hd_ofs = ALLOC(size_t); size_t* hd_ofs = ALLOC(size_t);
*hd_ofs = ofs; *hd_ofs = ofs;
upb_handlers_addcleanup(h, hd_ofs, free); upb_handlers_addcleanup(h, hd_ofs, xfree);
return hd_ofs; return hd_ofs;
} }
...@@ -69,7 +69,7 @@ static const void *newsubmsghandlerdata(upb_handlers* h, uint32_t ofs, ...@@ -69,7 +69,7 @@ static const void *newsubmsghandlerdata(upb_handlers* h, uint32_t ofs,
submsg_handlerdata_t *hd = ALLOC(submsg_handlerdata_t); submsg_handlerdata_t *hd = ALLOC(submsg_handlerdata_t);
hd->ofs = ofs; hd->ofs = ofs;
hd->md = upb_fielddef_msgsubdef(f); hd->md = upb_fielddef_msgsubdef(f);
upb_handlers_addcleanup(h, hd, free); upb_handlers_addcleanup(h, hd, xfree);
return hd; return hd;
} }
...@@ -99,7 +99,7 @@ static const void *newoneofhandlerdata(upb_handlers *h, ...@@ -99,7 +99,7 @@ static const void *newoneofhandlerdata(upb_handlers *h,
} else { } else {
hd->md = NULL; hd->md = NULL;
} }
upb_handlers_addcleanup(h, hd, free); upb_handlers_addcleanup(h, hd, xfree);
return hd; return hd;
} }
...@@ -135,7 +135,7 @@ static void* appendstr_handler(void *closure, ...@@ -135,7 +135,7 @@ static void* appendstr_handler(void *closure,
VALUE ary = (VALUE)closure; VALUE ary = (VALUE)closure;
VALUE str = rb_str_new2(""); VALUE str = rb_str_new2("");
rb_enc_associate(str, kRubyStringUtf8Encoding); rb_enc_associate(str, kRubyStringUtf8Encoding);
RepeatedField_push(ary, str); RepeatedField_push_native(ary, &str);
return (void*)str; return (void*)str;
} }
...@@ -146,7 +146,7 @@ static void* appendbytes_handler(void *closure, ...@@ -146,7 +146,7 @@ static void* appendbytes_handler(void *closure,
VALUE ary = (VALUE)closure; VALUE ary = (VALUE)closure;
VALUE str = rb_str_new2(""); VALUE str = rb_str_new2("");
rb_enc_associate(str, kRubyString8bitEncoding); rb_enc_associate(str, kRubyString8bitEncoding);
RepeatedField_push(ary, str); RepeatedField_push_native(ary, &str);
return (void*)str; return (void*)str;
} }
...@@ -182,6 +182,23 @@ static size_t stringdata_handler(void* closure, const void* hd, ...@@ -182,6 +182,23 @@ static size_t stringdata_handler(void* closure, const void* hd,
return len; return len;
} }
static bool stringdata_end_handler(void* closure, const void* hd) {
MessageHeader* msg = closure;
const size_t *ofs = hd;
VALUE rb_str = DEREF(msg, *ofs, VALUE);
rb_obj_freeze(rb_str);
return true;
}
static bool appendstring_end_handler(void* closure, const void* hd) {
VALUE ary = (VALUE)closure;
int size = RepeatedField_size(ary);
VALUE* last = RepeatedField_index_native(ary, size - 1);
VALUE rb_str = *last;
rb_obj_freeze(rb_str);
return true;
}
// Appends a submessage to a repeated field (a regular Ruby array for now). // Appends a submessage to a repeated field (a regular Ruby array for now).
static void *appendsubmsg_handler(void *closure, const void *hd) { static void *appendsubmsg_handler(void *closure, const void *hd) {
VALUE ary = (VALUE)closure; VALUE ary = (VALUE)closure;
...@@ -281,7 +298,7 @@ static bool endmap_handler(void *closure, const void *hd, upb_status* s) { ...@@ -281,7 +298,7 @@ static bool endmap_handler(void *closure, const void *hd, upb_status* s) {
&frame->value_storage); &frame->value_storage);
Map_index_set(frame->map, key, value); Map_index_set(frame->map, key, value);
free(frame); xfree(frame);
return true; return true;
} }
...@@ -360,6 +377,13 @@ static void *oneofbytes_handler(void *closure, ...@@ -360,6 +377,13 @@ static void *oneofbytes_handler(void *closure,
return (void*)str; return (void*)str;
} }
static bool oneofstring_end_handler(void* closure, const void* hd) {
MessageHeader* msg = closure;
const oneof_handlerdata_t *oneofdata = hd;
rb_obj_freeze(DEREF(msg, oneofdata->ofs, VALUE));
return true;
}
// Handler for a submessage field in a oneof. // Handler for a submessage field in a oneof.
static void *oneofsubmsg_handler(void *closure, static void *oneofsubmsg_handler(void *closure,
const void *hd) { const void *hd) {
...@@ -426,6 +450,7 @@ static void add_handlers_for_repeated_field(upb_handlers *h, ...@@ -426,6 +450,7 @@ static void add_handlers_for_repeated_field(upb_handlers *h,
appendbytes_handler : appendstr_handler, appendbytes_handler : appendstr_handler,
NULL); NULL);
upb_handlers_setstring(h, f, stringdata_handler, NULL); upb_handlers_setstring(h, f, stringdata_handler, NULL);
upb_handlers_setendstr(h, f, appendstring_end_handler, NULL);
break; break;
} }
case UPB_TYPE_MESSAGE: { case UPB_TYPE_MESSAGE: {
...@@ -462,6 +487,7 @@ static void add_handlers_for_singular_field(upb_handlers *h, ...@@ -462,6 +487,7 @@ static void add_handlers_for_singular_field(upb_handlers *h,
is_bytes ? bytes_handler : str_handler, is_bytes ? bytes_handler : str_handler,
&attr); &attr);
upb_handlers_setstring(h, f, stringdata_handler, &attr); upb_handlers_setstring(h, f, stringdata_handler, &attr);
upb_handlers_setendstr(h, f, stringdata_end_handler, &attr);
upb_handlerattr_uninit(&attr); upb_handlerattr_uninit(&attr);
break; break;
} }
...@@ -484,7 +510,7 @@ static void add_handlers_for_mapfield(upb_handlers* h, ...@@ -484,7 +510,7 @@ static void add_handlers_for_mapfield(upb_handlers* h,
map_handlerdata_t* hd = new_map_handlerdata(offset, map_msgdef, desc); map_handlerdata_t* hd = new_map_handlerdata(offset, map_msgdef, desc);
upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER; upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER;
upb_handlers_addcleanup(h, hd, free); upb_handlers_addcleanup(h, hd, xfree);
upb_handlerattr_sethandlerdata(&attr, hd); upb_handlerattr_sethandlerdata(&attr, hd);
upb_handlers_setstartsubmsg(h, fielddef, startmapentry_handler, &attr); upb_handlers_setstartsubmsg(h, fielddef, startmapentry_handler, &attr);
upb_handlerattr_uninit(&attr); upb_handlerattr_uninit(&attr);
...@@ -499,7 +525,7 @@ static void add_handlers_for_mapentry(const upb_msgdef* msgdef, ...@@ -499,7 +525,7 @@ static void add_handlers_for_mapentry(const upb_msgdef* msgdef,
map_handlerdata_t* hd = new_map_handlerdata(0, msgdef, desc); map_handlerdata_t* hd = new_map_handlerdata(0, msgdef, desc);
upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER; upb_handlerattr attr = UPB_HANDLERATTR_INITIALIZER;
upb_handlers_addcleanup(h, hd, free); upb_handlers_addcleanup(h, hd, xfree);
upb_handlerattr_sethandlerdata(&attr, hd); upb_handlerattr_sethandlerdata(&attr, hd);
upb_handlers_setendmsg(h, endmap_handler, &attr); upb_handlers_setendmsg(h, endmap_handler, &attr);
...@@ -546,6 +572,7 @@ static void add_handlers_for_oneof_field(upb_handlers *h, ...@@ -546,6 +572,7 @@ static void add_handlers_for_oneof_field(upb_handlers *h,
oneofbytes_handler : oneofstr_handler, oneofbytes_handler : oneofstr_handler,
&attr); &attr);
upb_handlers_setstring(h, f, stringdata_handler, NULL); upb_handlers_setstring(h, f, stringdata_handler, NULL);
upb_handlers_setendstr(h, f, oneofstring_end_handler, &attr);
break; break;
} }
case UPB_TYPE_MESSAGE: { case UPB_TYPE_MESSAGE: {
...@@ -863,9 +890,13 @@ static void putstr(VALUE str, const upb_fielddef *f, upb_sink *sink) { ...@@ -863,9 +890,13 @@ static void putstr(VALUE str, const upb_fielddef *f, upb_sink *sink) {
assert(BUILTIN_TYPE(str) == RUBY_T_STRING); assert(BUILTIN_TYPE(str) == RUBY_T_STRING);
// Ensure that the string has the correct encoding. We also check at field-set // We should be guaranteed that the string has the correct encoding because
// time, but the user may have mutated the string object since then. // we ensured this at assignment time and then froze the string.
native_slot_validate_string_encoding(upb_fielddef_type(f), str); if (upb_fielddef_type(f) == UPB_TYPE_STRING) {
assert(rb_enc_from_index(ENCODING_GET(value)) == kRubyStringUtf8Encoding);
} else {
assert(rb_enc_from_index(ENCODING_GET(value)) == kRubyString8bitEncoding);
}
upb_sink_startstr(sink, getsel(f, UPB_HANDLER_STARTSTR), RSTRING_LEN(str), upb_sink_startstr(sink, getsel(f, UPB_HANDLER_STARTSTR), RSTRING_LEN(str),
&subsink); &subsink);
......
...@@ -63,16 +63,16 @@ ...@@ -63,16 +63,16 @@
// construct a key byte sequence if needed. |out_key| and |out_length| provide // construct a key byte sequence if needed. |out_key| and |out_length| provide
// the resulting key data/length. // the resulting key data/length.
#define TABLE_KEY_BUF_LENGTH 8 // sizeof(uint64_t) #define TABLE_KEY_BUF_LENGTH 8 // sizeof(uint64_t)
static void table_key(Map* self, VALUE key, static VALUE table_key(Map* self, VALUE key,
char* buf, char* buf,
const char** out_key, const char** out_key,
size_t* out_length) { size_t* out_length) {
switch (self->key_type) { switch (self->key_type) {
case UPB_TYPE_BYTES: case UPB_TYPE_BYTES:
case UPB_TYPE_STRING: case UPB_TYPE_STRING:
// Strings: use string content directly. // Strings: use string content directly.
Check_Type(key, T_STRING); Check_Type(key, T_STRING);
native_slot_validate_string_encoding(self->key_type, key); key = native_slot_encode_and_freeze_string(self->key_type, key);
*out_key = RSTRING_PTR(key); *out_key = RSTRING_PTR(key);
*out_length = RSTRING_LEN(key); *out_length = RSTRING_LEN(key);
break; break;
...@@ -93,6 +93,8 @@ static void table_key(Map* self, VALUE key, ...@@ -93,6 +93,8 @@ static void table_key(Map* self, VALUE key,
assert(false); assert(false);
break; break;
} }
return key;
} }
static VALUE table_key_to_ruby(Map* self, const char* buf, size_t length) { static VALUE table_key_to_ruby(Map* self, const char* buf, size_t length) {
...@@ -357,7 +359,7 @@ VALUE Map_index(VALUE _self, VALUE key) { ...@@ -357,7 +359,7 @@ VALUE Map_index(VALUE _self, VALUE key) {
const char* keyval = NULL; const char* keyval = NULL;
size_t length = 0; size_t length = 0;
upb_value v; upb_value v;
table_key(self, key, keybuf, &keyval, &length); key = table_key(self, key, keybuf, &keyval, &length);
if (upb_strtable_lookup2(&self->table, keyval, length, &v)) { if (upb_strtable_lookup2(&self->table, keyval, length, &v)) {
void* mem = value_memory(&v); void* mem = value_memory(&v);
...@@ -383,7 +385,7 @@ VALUE Map_index_set(VALUE _self, VALUE key, VALUE value) { ...@@ -383,7 +385,7 @@ VALUE Map_index_set(VALUE _self, VALUE key, VALUE value) {
size_t length = 0; size_t length = 0;
upb_value v; upb_value v;
void* mem; void* mem;
table_key(self, key, keybuf, &keyval, &length); key = table_key(self, key, keybuf, &keyval, &length);
mem = value_memory(&v); mem = value_memory(&v);
native_slot_set(self->value_type, self->value_type_class, mem, value); native_slot_set(self->value_type, self->value_type_class, mem, value);
...@@ -411,7 +413,7 @@ VALUE Map_has_key(VALUE _self, VALUE key) { ...@@ -411,7 +413,7 @@ VALUE Map_has_key(VALUE _self, VALUE key) {
char keybuf[TABLE_KEY_BUF_LENGTH]; char keybuf[TABLE_KEY_BUF_LENGTH];
const char* keyval = NULL; const char* keyval = NULL;
size_t length = 0; size_t length = 0;
table_key(self, key, keybuf, &keyval, &length); key = table_key(self, key, keybuf, &keyval, &length);
if (upb_strtable_lookup2(&self->table, keyval, length, NULL)) { if (upb_strtable_lookup2(&self->table, keyval, length, NULL)) {
return Qtrue; return Qtrue;
...@@ -434,7 +436,7 @@ VALUE Map_delete(VALUE _self, VALUE key) { ...@@ -434,7 +436,7 @@ VALUE Map_delete(VALUE _self, VALUE key) {
const char* keyval = NULL; const char* keyval = NULL;
size_t length = 0; size_t length = 0;
upb_value v; upb_value v;
table_key(self, key, keybuf, &keyval, &length); key = table_key(self, key, keybuf, &keyval, &length);
if (upb_strtable_remove2(&self->table, keyval, length, &v)) { if (upb_strtable_remove2(&self->table, keyval, length, &v)) {
void* mem = value_memory(&v); void* mem = value_memory(&v);
......
...@@ -313,7 +313,7 @@ void native_slot_dup(upb_fieldtype_t type, void* to, void* from); ...@@ -313,7 +313,7 @@ void native_slot_dup(upb_fieldtype_t type, void* to, void* from);
void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from); void native_slot_deep_copy(upb_fieldtype_t type, void* to, void* from);
bool native_slot_eq(upb_fieldtype_t type, void* mem1, void* mem2); bool native_slot_eq(upb_fieldtype_t type, void* mem1, void* mem2);
void native_slot_validate_string_encoding(upb_fieldtype_t type, VALUE value); VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value);
void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE value); void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE value);
extern rb_encoding* kRubyStringUtf8Encoding; extern rb_encoding* kRubyStringUtf8Encoding;
...@@ -366,6 +366,7 @@ RepeatedField* ruby_to_RepeatedField(VALUE value); ...@@ -366,6 +366,7 @@ RepeatedField* ruby_to_RepeatedField(VALUE value);
VALUE RepeatedField_each(VALUE _self); VALUE RepeatedField_each(VALUE _self);
VALUE RepeatedField_index(int argc, VALUE* argv, VALUE _self); VALUE RepeatedField_index(int argc, VALUE* argv, VALUE _self);
void* RepeatedField_index_native(VALUE _self, int index); void* RepeatedField_index_native(VALUE _self, int index);
int RepeatedField_size(VALUE _self);
VALUE RepeatedField_index_set(VALUE _self, VALUE _index, VALUE val); VALUE RepeatedField_index_set(VALUE _self, VALUE _index, VALUE val);
void RepeatedField_reserve(RepeatedField* self, int new_size); void RepeatedField_reserve(RepeatedField* self, int new_size);
VALUE RepeatedField_push(VALUE _self, VALUE val); VALUE RepeatedField_push(VALUE _self, VALUE val);
......
...@@ -244,6 +244,11 @@ void* RepeatedField_index_native(VALUE _self, int index) { ...@@ -244,6 +244,11 @@ void* RepeatedField_index_native(VALUE _self, int index) {
return RepeatedField_memoryat(self, index, element_size); return RepeatedField_memoryat(self, index, element_size);
} }
int RepeatedField_size(VALUE _self) {
RepeatedField* self = ruby_to_RepeatedField(_self);
return self->size;
}
/* /*
* Private ruby method, used by RepeatedField.pop * Private ruby method, used by RepeatedField.pop
*/ */
......
...@@ -117,25 +117,24 @@ void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE val) { ...@@ -117,25 +117,24 @@ void native_slot_check_int_range_precision(upb_fieldtype_t type, VALUE val) {
} }
} }
void native_slot_validate_string_encoding(upb_fieldtype_t type, VALUE value) { VALUE native_slot_encode_and_freeze_string(upb_fieldtype_t type, VALUE value) {
bool bad_encoding = false; rb_encoding* desired_encoding = (type == UPB_TYPE_STRING) ?
rb_encoding* string_encoding = rb_enc_from_index(ENCODING_GET(value)); kRubyStringUtf8Encoding : kRubyString8bitEncoding;
if (type == UPB_TYPE_STRING) { VALUE desired_encoding_value = rb_enc_from_encoding(desired_encoding);
bad_encoding =
string_encoding != kRubyStringUtf8Encoding && // Note: this will not duplicate underlying string data unless necessary.
string_encoding != kRubyStringASCIIEncoding; value = rb_str_encode(value, desired_encoding_value, 0, Qnil);
} else {
bad_encoding = if (type == UPB_TYPE_STRING &&
string_encoding != kRubyString8bitEncoding; rb_enc_str_coderange(value) == ENC_CODERANGE_BROKEN) {
} rb_raise(rb_eEncodingError, "String is invalid UTF-8");
// Check that encoding is UTF-8 or ASCII (for string fields) or ASCII-8BIT
// (for bytes fields).
if (bad_encoding) {
rb_raise(rb_eTypeError, "Encoding for '%s' fields must be %s (was %s)",
(type == UPB_TYPE_STRING) ? "string" : "bytes",
(type == UPB_TYPE_STRING) ? "UTF-8 or ASCII" : "ASCII-8BIT",
rb_enc_name(string_encoding));
} }
// Ensure the data remains valid. Since we called #encode a moment ago,
// this does not freeze the string the user assigned.
rb_obj_freeze(value);
return value;
} }
void native_slot_set(upb_fieldtype_t type, VALUE type_class, void native_slot_set(upb_fieldtype_t type, VALUE type_class,
...@@ -181,8 +180,8 @@ void native_slot_set_value_and_case(upb_fieldtype_t type, VALUE type_class, ...@@ -181,8 +180,8 @@ void native_slot_set_value_and_case(upb_fieldtype_t type, VALUE type_class,
if (CLASS_OF(value) != rb_cString) { if (CLASS_OF(value) != rb_cString) {
rb_raise(rb_eTypeError, "Invalid argument for string field."); rb_raise(rb_eTypeError, "Invalid argument for string field.");
} }
native_slot_validate_string_encoding(type, value);
DEREF(memory, VALUE) = value; DEREF(memory, VALUE) = native_slot_encode_and_freeze_string(type, value);
break; break;
} }
case UPB_TYPE_MESSAGE: { case UPB_TYPE_MESSAGE: {
......
...@@ -11076,8 +11076,8 @@ static bool end_stringval(upb_json_parser *p) { ...@@ -11076,8 +11076,8 @@ static bool end_stringval(upb_json_parser *p) {
case UPB_TYPE_STRING: { case UPB_TYPE_STRING: {
upb_selector_t sel = getsel_for_handlertype(p, UPB_HANDLER_ENDSTR); upb_selector_t sel = getsel_for_handlertype(p, UPB_HANDLER_ENDSTR);
upb_sink_endstr(&p->top->sink, sel);
p->top--; p->top--;
upb_sink_endstr(&p->top->sink, sel);
break; break;
} }
......
...@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject { ...@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
*/ */
@JRubyMethod(name = "[]=") @JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) { public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass); key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass); value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
IRubyObject symbol; IRubyObject symbol;
if (valueType == Descriptors.FieldDescriptor.Type.ENUM && if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
Utils.isRubyNum(value) && Utils.isRubyNum(value) &&
......
...@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject { ...@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
break; break;
case BYTES: case BYTES:
case STRING: case STRING:
Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value); Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
RubyString str = (RubyString) value; RubyString str = (RubyString) value;
switch (fieldDescriptor.getType()) { switch (fieldDescriptor.getType()) {
case BYTES: case BYTES:
...@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject { ...@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
} }
} }
if (addValue) { if (addValue) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
this.fields.put(fieldDescriptor, value); this.fields.put(fieldDescriptor, value);
} else { } else {
this.fields.remove(fieldDescriptor); this.fields.remove(fieldDescriptor);
......
...@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject { ...@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
@JRubyMethod(name = "[]=") @JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) { public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
int arrIndex = normalizeArrayIndex(index); int arrIndex = normalizeArrayIndex(index);
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
IRubyObject defaultValue = defaultValue(context); IRubyObject defaultValue = defaultValue(context);
for (int i = this.storage.size(); i < arrIndex; i++) { for (int i = this.storage.size(); i < arrIndex; i++) {
this.storage.set(i, defaultValue); this.storage.set(i, defaultValue);
...@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject { ...@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
public IRubyObject push(ThreadContext context, IRubyObject value) { public IRubyObject push(ThreadContext context, IRubyObject value) {
if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE && if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
value == context.runtime.getNil())) { value == context.runtime.getNil())) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
} }
this.storage.add(value); this.storage.add(value);
return this.storage; return this.storage;
......
...@@ -64,8 +64,8 @@ public class Utils { ...@@ -64,8 +64,8 @@ public class Utils {
return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase()); return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
} }
public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
IRubyObject value, RubyModule typeClass) { IRubyObject value, RubyModule typeClass) {
Ruby runtime = context.runtime; Ruby runtime = context.runtime;
Object val; Object val;
switch(fieldType) { switch(fieldType) {
...@@ -106,7 +106,7 @@ public class Utils { ...@@ -106,7 +106,7 @@ public class Utils {
break; break;
case BYTES: case BYTES:
case STRING: case STRING:
validateStringEncoding(context.runtime, fieldType, value); value = validateStringEncoding(context, fieldType, value);
break; break;
case MESSAGE: case MESSAGE:
if (value.getMetaClass() != typeClass) { if (value.getMetaClass() != typeClass) {
...@@ -127,6 +127,7 @@ public class Utils { ...@@ -127,6 +127,7 @@ public class Utils {
default: default:
break; break;
} }
return value;
} }
public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) { public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
...@@ -148,10 +149,16 @@ public class Utils { ...@@ -148,10 +149,16 @@ public class Utils {
return runtime.newFloat((Double) value); return runtime.newFloat((Double) value);
case BOOL: case BOOL:
return (Boolean) value ? runtime.getTrue() : runtime.getFalse(); return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
case BYTES: case BYTES: {
return runtime.newString(((ByteString) value).toStringUtf8()); IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
case STRING: wrapped.setFrozen(true);
return runtime.newString(value.toString()); return wrapped;
}
case STRING: {
IRubyObject wrapped = runtime.newString(value.toString());
wrapped.setFrozen(true);
return wrapped;
}
default: default:
return runtime.getNil(); return runtime.getNil();
} }
...@@ -180,25 +187,21 @@ public class Utils { ...@@ -180,25 +187,21 @@ public class Utils {
} }
} }
public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) { public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
if (!(value instanceof RubyString)) if (!(value instanceof RubyString))
throw runtime.newTypeError("Invalid argument for string field."); throw context.runtime.newTypeError("Invalid argument for string field.");
Encoding encoding = ((RubyString) value).getEncoding();
switch(type) { switch(type) {
case BYTES: case BYTES:
if (encoding != ASCIIEncoding.INSTANCE) value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
throw runtime.newTypeError("Encoding for bytes fields" +
" must be \"ASCII-8BIT\", but was " + encoding);
break; break;
case STRING: case STRING:
if (encoding != UTF8Encoding.INSTANCE value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
&& encoding != USASCIIEncoding.INSTANCE)
throw runtime.newTypeError("Encoding for string fields" +
" must be \"UTF-8\" or \"ASCII\", but was " + encoding);
break; break;
default: default:
break; break;
} }
value.setFrozen(true);
return value;
} }
public static void checkNameAvailability(ThreadContext context, String name) { public static void checkNameAvailability(ThreadContext context, String name) {
......
...@@ -255,14 +255,17 @@ module BasicTest ...@@ -255,14 +255,17 @@ module BasicTest
m = TestMessage.new m = TestMessage.new
# Assigning a normal (ASCII or UTF8) string to a bytes field, or # Assigning a normal (ASCII or UTF8) string to a bytes field, or
# ASCII-8BIT to a string field, raises an error. # ASCII-8BIT to a string field will convert to the proper encoding.
assert_raise TypeError do m.optional_bytes = "Test string ASCII".encode!('ASCII')
m.optional_bytes = "Test string ASCII".encode!('ASCII') assert m.optional_bytes.frozen?
end assert_equal Encoding::ASCII_8BIT, m.optional_bytes.encoding
assert_raise TypeError do assert_equal "Test string ASCII", m.optional_bytes
assert_raise Encoding::UndefinedConversionError do
m.optional_bytes = "Test string UTF-8 \u0100".encode!('UTF-8') m.optional_bytes = "Test string UTF-8 \u0100".encode!('UTF-8')
end end
assert_raise TypeError do
assert_raise Encoding::UndefinedConversionError do
m.optional_string = ["FFFF"].pack('H*') m.optional_string = ["FFFF"].pack('H*')
end end
...@@ -270,11 +273,10 @@ module BasicTest ...@@ -270,11 +273,10 @@ module BasicTest
m.optional_bytes = ["FFFF"].pack('H*') m.optional_bytes = ["FFFF"].pack('H*')
m.optional_string = "\u0100" m.optional_string = "\u0100"
# strings are mutable so we can do this, but serialize should catch it. # strings are immutable so we can't do this, but serialize should catch it.
m.optional_string = "asdf".encode!('UTF-8') m.optional_string = "asdf".encode!('UTF-8')
m.optional_string.encode!('ASCII-8BIT') assert_raise RuntimeError do
assert_raise TypeError do m.optional_string.encode!('ASCII-8BIT')
data = TestMessage.encode(m)
end end
end end
...@@ -558,7 +560,7 @@ module BasicTest ...@@ -558,7 +560,7 @@ module BasicTest
assert_raise TypeError do assert_raise TypeError do
m[1] = 1 m[1] = 1
end end
assert_raise TypeError do assert_raise Encoding::UndefinedConversionError do
bytestring = ["FFFF"].pack("H*") bytestring = ["FFFF"].pack("H*")
m[bytestring] = 1 m[bytestring] = 1
end end
...@@ -566,9 +568,8 @@ module BasicTest ...@@ -566,9 +568,8 @@ module BasicTest
m = Google::Protobuf::Map.new(:bytes, :int32) m = Google::Protobuf::Map.new(:bytes, :int32)
bytestring = ["FFFF"].pack("H*") bytestring = ["FFFF"].pack("H*")
m[bytestring] = 1 m[bytestring] = 1
assert_raise TypeError do # Allowed -- we will automatically convert to ASCII-8BIT.
m["asdf"] = 1 m["asdf"] = 1
end
assert_raise TypeError do assert_raise TypeError do
m[1] = 1 m[1] = 1
end end
...@@ -853,15 +854,22 @@ module BasicTest ...@@ -853,15 +854,22 @@ module BasicTest
def test_encode_decode_helpers def test_encode_decode_helpers
m = TestMessage.new(:optional_string => 'foo', :repeated_string => ['bar1', 'bar2']) m = TestMessage.new(:optional_string => 'foo', :repeated_string => ['bar1', 'bar2'])
assert_equal 'foo', m.optional_string
assert_equal ['bar1', 'bar2'], m.repeated_string
json = m.to_json json = m.to_json
m2 = TestMessage.decode_json(json) m2 = TestMessage.decode_json(json)
assert m2.optional_string == 'foo' assert_equal 'foo', m2.optional_string
assert m2.repeated_string == ['bar1', 'bar2'] assert_equal ['bar1', 'bar2'], m2.repeated_string
if RUBY_PLATFORM != "java"
assert m2.optional_string.frozen?
assert m2.repeated_string[0].frozen?
end
proto = m.to_proto proto = m.to_proto
m2 = TestMessage.decode(proto) m2 = TestMessage.decode(proto)
assert m2.optional_string == 'foo' assert_equal 'foo', m2.optional_string
assert m2.repeated_string == ['bar1', 'bar2'] assert_equal ['bar1', 'bar2'], m2.repeated_string
end end
def test_protobuf_encode_decode_helpers def test_protobuf_encode_decode_helpers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment