diff --git a/src/stringio.c b/src/stringio.c index 5560e00..353aacd 100644 --- a/src/stringio.c +++ b/src/stringio.c @@ -468,6 +468,36 @@ stringio_getc(mrb_state *mrb, mrb_value self) return ret; } +static mrb_value +stringio_ungetc(mrb_state *mrb, mrb_value self) +{ + struct StringIO *ptr = StringIO(self); + mrb_value string = stringio_iv_get("@string"); + mrb_int in_len; + const char* in_data; + mrb_get_args(mrb, "s", &in_data, &in_len); + + mrb_int flags = mrb_fixnum(stringio_iv_get("@flags")); + if ((flags & FMODE_READABLE) == 0) { + mrb_raise(mrb, E_IOERROR, "not opened for reading"); + } + + if (in_len == 0) { + return mrb_nil_value(); + } + + if (in_len != 1) { + mrb_raisef(mrb, E_IOERROR, "mruby-ext: ungetc is only supported for single byte string. len: %S", mrb_fixnum_value(in_len)); + } + if (ptr->pos <= 0) { + mrb_raise(mrb, E_IOERROR, "mruby-ext: ungetc not supported in beginning of stream"); + } + ptr->pos -= 1; + + mrb_assert(RSTRING_PTR(string)[ptr->pos] == *in_data); + return mrb_nil_value(); +} + static mrb_value stringio_gets(mrb_state *mrb, mrb_value self) { @@ -653,6 +683,7 @@ mrb_mruby_stringio_gem_init(mrb_state* mrb) mrb_define_alias(mrb, stringio, "syswrite", "write"); mrb_define_method(mrb, stringio, "getc", stringio_getc, MRB_ARGS_ANY()); mrb_define_method(mrb, stringio, "gets", stringio_gets, MRB_ARGS_ANY()); + mrb_define_method(mrb, stringio, "ungetc", stringio_ungetc, MRB_ARGS_REQ(1)); mrb_define_method(mrb, stringio, "seek", stringio_seek, MRB_ARGS_ANY()); mrb_define_method(mrb, stringio, "size", stringio_size, MRB_ARGS_NONE()); mrb_define_alias(mrb, stringio, "length", "size"); diff --git a/test/stringio.rb b/test/stringio.rb index 26c1497..7227289 100644 --- a/test/stringio.rb +++ b/test/stringio.rb @@ -206,6 +206,19 @@ def o.to_s; "baz"; end assert_equal nil, strio.getc end +assert 'StringIO#ungetc' do + strio = StringIO.new("abc") + strio.ungetc '' + assert_equal 'abc', strio.string + assert_raise(IOError) { strio.ungetc 'a' } + assert_equal 'a', strio.getc + assert_nil strio.ungetc 'a' + assert_equal 'a', strio.getc + + strio = StringIO.new("abc", 'w') + assert_raise(IOError) { strio.ungetc 'a' } +end + assert 'StringIO#gets' do io = StringIO.new("this>is>an>example") assert_equal "this>", io.gets(">")