Commit d55733c7 authored by Adam Greene's avatar Adam Greene

return nil if array index indicie is out of bounds

ruby arrays don't throw an exception; they return nil.  Lets do the
same!

this fix also includes the ability to use negative array indicies
parent c70b6058
...@@ -63,7 +63,7 @@ To build this Ruby extension, you will need: ...@@ -63,7 +63,7 @@ To build this Ruby extension, you will need:
To Build the JRuby extension, you will need: To Build the JRuby extension, you will need:
* Maven * Maven
* The latest version of the protobuf java library * The latest version of the protobuf java library (see ../java/README.md)
* Install JRuby via rbenv or RVM * Install JRuby via rbenv or RVM
First switch to the desired platform with rbenv or RVM. First switch to the desired platform with rbenv or RVM.
...@@ -75,7 +75,8 @@ Then install the required Ruby gems: ...@@ -75,7 +75,8 @@ Then install the required Ruby gems:
Then build the Gem: Then build the Gem:
$ rake gem $ rake
$ rake clobber_package gem
$ gem install `ls pkg/google-protobuf-*.gem` $ gem install `ls pkg/google-protobuf-*.gem`
To run the specs: To run the specs:
......
...@@ -6,6 +6,9 @@ require "rake/testtask" ...@@ -6,6 +6,9 @@ require "rake/testtask"
spec = Gem::Specification.load("google-protobuf.gemspec") spec = Gem::Specification.load("google-protobuf.gemspec")
if RUBY_PLATFORM == "java" if RUBY_PLATFORM == "java"
if `which mvn` == ''
raise ArgumentError, "maven needs to be installed"
end
task :clean do task :clean do
system("mvn clean") system("mvn clean")
end end
......
...@@ -47,6 +47,15 @@ RepeatedField* ruby_to_RepeatedField(VALUE _self) { ...@@ -47,6 +47,15 @@ RepeatedField* ruby_to_RepeatedField(VALUE _self) {
return self; return self;
} }
static int index_position(VALUE _index, RepeatedField* repeated_field) {
int index = NUM2INT(_index);
if (index < 0 && repeated_field->size > 0) {
index = repeated_field->size + index;
}
return index;
}
/* /*
* call-seq: * call-seq:
* RepeatedField.each(&block) * RepeatedField.each(&block)
...@@ -74,8 +83,7 @@ VALUE RepeatedField_each(VALUE _self) { ...@@ -74,8 +83,7 @@ VALUE RepeatedField_each(VALUE _self) {
* call-seq: * call-seq:
* RepeatedField.[](index) => value * RepeatedField.[](index) => value
* *
* Accesses the element at the given index. Throws an exception on out-of-bounds * Accesses the element at the given index. Returns nil on out-of-bounds
* errors.
*/ */
VALUE RepeatedField_index(VALUE _self, VALUE _index) { VALUE RepeatedField_index(VALUE _self, VALUE _index) {
RepeatedField* self = ruby_to_RepeatedField(_self); RepeatedField* self = ruby_to_RepeatedField(_self);
...@@ -83,9 +91,9 @@ VALUE RepeatedField_index(VALUE _self, VALUE _index) { ...@@ -83,9 +91,9 @@ VALUE RepeatedField_index(VALUE _self, VALUE _index) {
upb_fieldtype_t field_type = self->field_type; upb_fieldtype_t field_type = self->field_type;
VALUE field_type_class = self->field_type_class; VALUE field_type_class = self->field_type_class;
int index = NUM2INT(_index); int index = index_position(_index, self);
if (index < 0 || index >= self->size) { if (index < 0 || index >= self->size) {
rb_raise(rb_eRangeError, "Index out of range"); return Qnil;
} }
void* memory = (void *) (((uint8_t *)self->elements) + index * element_size); void* memory = (void *) (((uint8_t *)self->elements) + index * element_size);
...@@ -105,9 +113,9 @@ VALUE RepeatedField_index_set(VALUE _self, VALUE _index, VALUE val) { ...@@ -105,9 +113,9 @@ VALUE RepeatedField_index_set(VALUE _self, VALUE _index, VALUE val) {
VALUE field_type_class = self->field_type_class; VALUE field_type_class = self->field_type_class;
int element_size = native_slot_size(field_type); int element_size = native_slot_size(field_type);
int index = NUM2INT(_index); int index = index_position(_index, self);
if (index < 0 || index >= (INT_MAX - 1)) { if (index < 0 || index >= (INT_MAX - 1)) {
rb_raise(rb_eRangeError, "Index out of range"); return Qnil;
} }
if (index >= self->size) { if (index >= self->size) {
RepeatedField_reserve(self, index + 1); RepeatedField_reserve(self, index + 1);
......
...@@ -78,7 +78,7 @@ ...@@ -78,7 +78,7 @@
<dependency> <dependency>
<groupId>com.google.protobuf</groupId> <groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId> <artifactId>protobuf-java</artifactId>
<version>3.0.0-pre</version> <version>3.0.0-alpha-3-pre</version>
</dependency> </dependency>
</dependencies> </dependencies>
</project> </project>
...@@ -246,16 +246,15 @@ public class RubyMessage extends RubyObject { ...@@ -246,16 +246,15 @@ public class RubyMessage extends RubyObject {
public IRubyObject dup(ThreadContext context) { public IRubyObject dup(ThreadContext context) {
RubyMessage dup = (RubyMessage) metaClass.newInstance(context, Block.NULL_BLOCK); RubyMessage dup = (RubyMessage) metaClass.newInstance(context, Block.NULL_BLOCK);
IRubyObject value; IRubyObject value;
for (Descriptors.FieldDescriptor fieldDescriptor : builder.getAllFields().keySet()) { for (Descriptors.FieldDescriptor fieldDescriptor : this.descriptor.getFields()) {
if (fieldDescriptor.isRepeated()) { if (fieldDescriptor.isRepeated()) {
dup.repeatedFields.put(fieldDescriptor, getRepeatedField(context, fieldDescriptor)); dup.addRepeatedField(fieldDescriptor, this.getRepeatedField(context, fieldDescriptor));
} else if (builder.hasField(fieldDescriptor)) { } else if (fields.containsKey(fieldDescriptor)) {
dup.fields.put(fieldDescriptor, wrapField(context, fieldDescriptor, builder.getField(fieldDescriptor))); dup.fields.put(fieldDescriptor, fields.get(fieldDescriptor));
} else if (this.builder.hasField(fieldDescriptor)) {
dup.fields.put(fieldDescriptor, wrapField(context, fieldDescriptor, this.builder.getField(fieldDescriptor)));
} }
} }
for (Descriptors.FieldDescriptor fieldDescriptor : fields.keySet()) {
dup.fields.put(fieldDescriptor, fields.get(fieldDescriptor));
}
for (Descriptors.FieldDescriptor fieldDescriptor : maps.keySet()) { for (Descriptors.FieldDescriptor fieldDescriptor : maps.keySet()) {
dup.maps.put(fieldDescriptor, maps.get(fieldDescriptor)); dup.maps.put(fieldDescriptor, maps.get(fieldDescriptor));
} }
...@@ -411,6 +410,7 @@ public class RubyMessage extends RubyObject { ...@@ -411,6 +410,7 @@ public class RubyMessage extends RubyObject {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
ret.push(context, wrapField(context, fieldDescriptor, this.builder.getRepeatedField(fieldDescriptor, i))); ret.push(context, wrapField(context, fieldDescriptor, this.builder.getRepeatedField(fieldDescriptor, i)));
} }
addRepeatedField(fieldDescriptor, ret);
return ret; return ret;
} }
......
...@@ -108,8 +108,9 @@ public class RubyRepeatedField extends RubyObject { ...@@ -108,8 +108,9 @@ public class RubyRepeatedField extends RubyObject {
*/ */
@JRubyMethod(name = "[]=") @JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) { public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
int arrIndex = normalizeArrayIndex(index);
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
this.storage.set(RubyNumeric.num2int(index), value); this.storage.set(arrIndex, value);
return context.runtime.getNil(); return context.runtime.getNil();
} }
...@@ -117,12 +118,15 @@ public class RubyRepeatedField extends RubyObject { ...@@ -117,12 +118,15 @@ public class RubyRepeatedField extends RubyObject {
* call-seq: * call-seq:
* RepeatedField.[](index) => value * RepeatedField.[](index) => value
* *
* Accesses the element at the given index. Throws an exception on out-of-bounds * Accesses the element at the given index. Returns nil on out-of-bounds
* errors.
*/ */
@JRubyMethod(name = "[]") @JRubyMethod(name = "[]")
public IRubyObject index(ThreadContext context, IRubyObject index) { public IRubyObject index(ThreadContext context, IRubyObject index) {
return this.storage.eltInternal(RubyNumeric.num2int(index)); int arrIndex = normalizeArrayIndex(index);
if (arrIndex < 0 || arrIndex >= this.storage.size()) {
return context.runtime.getNil();
}
return this.storage.eltInternal(arrIndex);
} }
/* /*
...@@ -134,8 +138,7 @@ public class RubyRepeatedField extends RubyObject { ...@@ -134,8 +138,7 @@ public class RubyRepeatedField extends RubyObject {
@JRubyMethod(rest = true) @JRubyMethod(rest = true)
public IRubyObject insert(ThreadContext context, IRubyObject[] args) { public IRubyObject insert(ThreadContext context, IRubyObject[] args) {
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
Utils.checkType(context, fieldType, args[i], (RubyModule) typeClass); push(context, args[i]);
this.storage.add(args[i]);
} }
return context.runtime.getNil(); return context.runtime.getNil();
} }
...@@ -385,6 +388,15 @@ public class RubyRepeatedField extends RubyObject { ...@@ -385,6 +388,15 @@ public class RubyRepeatedField extends RubyObject {
} }
} }
private int normalizeArrayIndex(IRubyObject index) {
int arrIndex = RubyNumeric.num2int(index);
int arrSize = this.storage.size();
if (arrIndex < 0 && arrSize > 0) {
arrIndex = arrSize + arrIndex;
}
return arrIndex;
}
private RubyArray storage; private RubyArray storage;
private Descriptors.FieldDescriptor.Type fieldType; private Descriptors.FieldDescriptor.Type fieldType;
private IRubyObject typeClass; private IRubyObject typeClass;
......
...@@ -314,6 +314,17 @@ module BasicTest ...@@ -314,6 +314,17 @@ module BasicTest
assert l4 == [0, 0, 0, 0, 0, 42, 100, 101, 102] assert l4 == [0, 0, 0, 0, 0, 42, 100, 101, 102]
end end
def test_parent_rptfield
#make sure we set the RepeatedField and can add to it
m = TestMessage.new
assert m.repeated_string == []
m.repeated_string << 'ok'
m.repeated_string.push('ok2')
assert m.repeated_string == ['ok', 'ok2']
m.repeated_string += ['ok3']
assert m.repeated_string == ['ok', 'ok2', 'ok3']
end
def test_rptfield_msg def test_rptfield_msg
l = Google::Protobuf::RepeatedField.new(:message, TestMessage) l = Google::Protobuf::RepeatedField.new(:message, TestMessage)
l.push TestMessage.new l.push TestMessage.new
...@@ -383,10 +394,31 @@ module BasicTest ...@@ -383,10 +394,31 @@ module BasicTest
length_methods.each do |lm| length_methods.each do |lm|
assert l.send(lm) == 0 assert l.send(lm) == 0
end end
# out of bounds returns a nil
assert l[0] == nil
assert l[1] == nil
assert l[-1] == nil
l.push 4 l.push 4
length_methods.each do |lm| length_methods.each do |lm|
assert l.send(lm) == 1 assert l.send(lm) == 1
end
assert l[0] == 4
assert l[1] == nil
assert l[-1] == 4
assert l[-2] == nil
l.push 2
length_methods.each do |lm|
assert l.send(lm) == 2
end end
assert l[0] == 4
assert l[1] == 2
assert l[2] == nil
assert l[-1] == 2
assert l[-2] == 4
assert l[-3] == nil
#adding out of scope will backfill with empty objects
end end
def test_map_basic def test_map_basic
...@@ -724,9 +756,12 @@ module BasicTest ...@@ -724,9 +756,12 @@ module BasicTest
m = TestMessage.new m = TestMessage.new
m.optional_string = "hello" m.optional_string = "hello"
m.optional_int32 = 42 m.optional_int32 = 42
m.repeated_msg.push TestMessage2.new(:foo => 100) tm1 = TestMessage2.new(:foo => 100)
m.repeated_msg.push TestMessage2.new(:foo => 200) tm2 = TestMessage2.new(:foo => 200)
m.repeated_msg.push tm1
assert m.repeated_msg[-1] == tm1
m.repeated_msg.push tm2
assert m.repeated_msg[-1] == tm2
m2 = m.dup m2 = m.dup
assert m == m2 assert m == m2
m.optional_int32 += 1 m.optional_int32 += 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment