diff --git a/adapter/dynamodb_transcoder.go b/adapter/dynamodb_transcoder.go index 3b6734a..06ee0ba 100644 --- a/adapter/dynamodb_transcoder.go +++ b/adapter/dynamodb_transcoder.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" "github.com/cockroachdb/errors" ) @@ -14,7 +15,8 @@ func newDynamoDBTranscoder() *dynamodbTranscoder { // create new transcoder } type attributeValue struct { - S string `json:"S"` + S string `json:"S,omitempty"` + L []attributeValue `json:"L,omitempty"` } type putItemInput struct { @@ -43,16 +45,8 @@ func (t *dynamodbTranscoder) PutItemToRequest(b []byte) (*kv.OperationGroup[kv.O if !ok { return nil, errors.New("missing value attribute") } - return &kv.OperationGroup[kv.OP]{ - IsTxn: false, - Elems: []*kv.Elem[kv.OP]{ - { - Op: kv.Put, - Key: []byte(keyAttr.S), - Value: []byte(valAttr.S), - }, - }, - }, nil + + return t.valueAttrToOps([]byte(keyAttr.S), valAttr) } func (t *dynamodbTranscoder) TransactWriteItemsToRequest(b []byte) (*kv.OperationGroup[kv.OP], error) { @@ -74,11 +68,12 @@ func (t *dynamodbTranscoder) TransactWriteItemsToRequest(b []byte) (*kv.Operatio if !ok { return nil, errors.New("missing value attribute") } - elems = append(elems, &kv.Elem[kv.OP]{ - Op: kv.Put, - Key: []byte(keyAttr.S), - Value: []byte(valAttr.S), - }) + + ops, err := t.valueAttrToOps([]byte(keyAttr.S), valAttr) + if err != nil { + return nil, err + } + elems = append(elems, ops.Elems...) } return &kv.OperationGroup[kv.OP]{ @@ -86,3 +81,45 @@ func (t *dynamodbTranscoder) TransactWriteItemsToRequest(b []byte) (*kv.Operatio Elems: elems, }, nil } + +func (t *dynamodbTranscoder) valueAttrToOps(key []byte, val attributeValue) (*kv.OperationGroup[kv.OP], error) { + // List handling: only lists of scalar strings are supported. + if len(val.L) > 0 { + var elems []*kv.Elem[kv.OP] + for i, item := range val.L { + if len(item.L) > 0 { + return nil, errors.New("nested lists are not supported") + } + elems = append(elems, &kv.Elem[kv.OP]{ + Op: kv.Put, + Key: store.ListItemKey(key, int64(i)), + Value: []byte(item.S), + }) + } + meta := store.ListMeta{ + Head: 0, + Tail: int64(len(val.L)), + Len: int64(len(val.L)), + } + b, err := store.MarshalListMeta(meta) + if err != nil { + return nil, errors.WithStack(err) + } + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: store.ListMetaKey(key), Value: b}) + + return &kv.OperationGroup[kv.OP]{IsTxn: true, Elems: elems}, nil + } + + // Default: simple string (allow empty string). Reject only when both S and L are absent. + if val.S == "" && len(val.L) == 0 { + return nil, errors.New("unsupported attribute type (only S or L of S)") + } + return &kv.OperationGroup[kv.OP]{ + IsTxn: false, + Elems: []*kv.Elem[kv.OP]{{ + Op: kv.Put, + Key: key, + Value: []byte(val.S), + }}, + }, nil +} diff --git a/adapter/redis.go b/adapter/redis.go index c71de71..b48ee76 100644 --- a/adapter/redis.go +++ b/adapter/redis.go @@ -3,7 +3,6 @@ package adapter import ( "bytes" "context" - "encoding/json" "math" "net" "sort" @@ -41,6 +40,7 @@ type RedisServer struct { store store.ScanStore coordinator kv.Coordinator redisTranscoder *redisTranscoder + listStore *store.ListStore // TODO manage membership from raft log leaderRedis map[raft.ServerAddress]string @@ -72,12 +72,15 @@ type redisResult struct { err error } +func store2list(st store.ScanStore) *store.ListStore { return store.NewListStore(st) } + func NewRedisServer(listen net.Listener, store store.ScanStore, coordinate *kv.Coordinate, leaderRedis map[raft.ServerAddress]string) *RedisServer { r := &RedisServer{ listen: listen, store: store, coordinator: coordinate, redisTranscoder: newRedisTranscoder(), + listStore: store2list(store), leaderRedis: leaderRedis, } @@ -172,6 +175,17 @@ func (r *RedisServer) ping(conn redcon.Conn, _ redcon.Command) { } func (r *RedisServer) set(conn redcon.Conn, cmd redcon.Command) { + // Prevent overwriting list keys with string values without cleanup. + isList, err := r.isListKey(context.Background(), cmd.Args[1]) + if err != nil { + conn.WriteError(err.Error()) + return + } + if isList { + conn.WriteError("WRONGTYPE Operation against a key holding the wrong kind of value") + return + } + res, err := r.redisTranscoder.SetToRequest(cmd.Args[1], cmd.Args[2]) if err != nil { conn.WriteError(err.Error()) @@ -188,6 +202,14 @@ func (r *RedisServer) set(conn redcon.Conn, cmd redcon.Command) { } func (r *RedisServer) get(conn redcon.Conn, cmd redcon.Command) { + if ok, err := r.isListKey(context.Background(), cmd.Args[1]); err != nil { + conn.WriteError(err.Error()) + return + } else if ok { + conn.WriteError("WRONGTYPE Operation against a key holding the wrong kind of value") + return + } + if r.coordinator.IsLeader() { v, err := r.store.Get(context.Background(), cmd.Args[1]) if err != nil { @@ -218,6 +240,18 @@ func (r *RedisServer) get(conn redcon.Conn, cmd redcon.Command) { } func (r *RedisServer) del(conn redcon.Conn, cmd redcon.Command) { + if ok, err := r.isListKey(context.Background(), cmd.Args[1]); err != nil { + conn.WriteError(err.Error()) + return + } else if ok { + if err := r.deleteList(context.Background(), cmd.Args[1]); err != nil { + conn.WriteError(err.Error()) + return + } + conn.WriteInt(1) + return + } + res, err := r.redisTranscoder.DeleteToRequest(cmd.Args[1]) if err != nil { conn.WriteError(err.Error()) @@ -234,6 +268,14 @@ func (r *RedisServer) del(conn redcon.Conn, cmd redcon.Command) { } func (r *RedisServer) exists(conn redcon.Conn, cmd redcon.Command) { + if ok, err := r.isListKey(context.Background(), cmd.Args[1]); err != nil { + conn.WriteError(err.Error()) + return + } else if ok { + conn.WriteInt(1) + return + } + ok, err := r.store.Exists(context.Background(), cmd.Args[1]) if err != nil { conn.WriteError(err.Error()) @@ -276,39 +318,61 @@ func (r *RedisServer) keys(conn redcon.Conn, cmd redcon.Command) { } func (r *RedisServer) localKeys(pattern []byte) ([][]byte, error) { - // If an asterisk (*) is not included, the match will be exact, - // so check if the key exists. if !bytes.Contains(pattern, []byte("*")) { - res, err := r.store.Exists(context.Background(), pattern) - if err != nil { - return nil, errors.WithStack(err) - } - if res { - return [][]byte{bytes.Clone(pattern)}, nil - } - return [][]byte{}, nil + return r.localKeysExact(pattern) } + return r.localKeysPattern(pattern) +} - var start []byte - switch { - case bytes.Equal(pattern, []byte("*")): - start = nil - default: - start = bytes.ReplaceAll(pattern, []byte("*"), nil) +func (r *RedisServer) localKeysExact(pattern []byte) ([][]byte, error) { + res, err := r.store.Exists(context.Background(), pattern) + if err != nil { + return nil, errors.WithStack(err) } + if res { + return [][]byte{bytes.Clone(pattern)}, nil + } + return [][]byte{}, nil +} + +func (r *RedisServer) localKeysPattern(pattern []byte) ([][]byte, error) { + start := r.patternStart(pattern) keys, err := r.store.Scan(context.Background(), start, nil, math.MaxInt) if err != nil { return nil, errors.WithStack(err) } - out := make([][]byte, 0, len(keys)) - for _, kvPair := range keys { - out = append(out, kvPair.Key) + keyset := r.collectUserKeys(keys) + + out := make([][]byte, 0, len(keyset)) + for _, v := range keyset { + out = append(out, v) } return out, nil } +func (r *RedisServer) patternStart(pattern []byte) []byte { + if bytes.Equal(pattern, []byte("*")) { + return nil + } + return bytes.ReplaceAll(pattern, []byte("*"), nil) +} + +func (r *RedisServer) collectUserKeys(kvs []*store.KVPair) map[string][]byte { + keyset := map[string][]byte{} + for _, kvPair := range kvs { + if store.IsListMetaKey(kvPair.Key) || store.IsListItemKey(kvPair.Key) { + if userKey := store.ExtractListUserKey(kvPair.Key); userKey != nil { + keyset[string(userKey)] = userKey + } + continue + } + keyset[string(kvPair.Key)] = kvPair.Key + } + return keyset +} + func (r *RedisServer) proxyKeys(pattern []byte) ([]string, error) { leader := r.coordinator.RaftLeader() if leader == "" { @@ -374,16 +438,22 @@ func (r *RedisServer) exec(conn redcon.Conn, _ redcon.Command) { type txnValue struct { raw []byte - list []string - isList bool deleted bool dirty bool loaded bool } type txnContext struct { - server *RedisServer - working map[string]*txnValue + server *RedisServer + working map[string]*txnValue + listStates map[string]*listTxnState +} + +type listTxnState struct { + meta store.ListMeta + metaExists bool + appends [][]byte + deleted bool } func (t *txnContext) load(key []byte) (*txnValue, error) { @@ -402,6 +472,30 @@ func (t *txnContext) load(key []byte) (*txnValue, error) { return tv, nil } +func (t *txnContext) loadListState(key []byte) (*listTxnState, error) { + k := string(key) + if st, ok := t.listStates[k]; ok { + return st, nil + } + + meta, exists, err := t.server.loadListMeta(context.Background(), key) + if err != nil { + return nil, err + } + + st := &listTxnState{ + meta: meta, + metaExists: exists, + appends: [][]byte{}, + } + t.listStates[k] = st + return st, nil +} + +func (t *txnContext) listLength(st *listTxnState) int64 { + return st.meta.Len + int64(len(st.appends)) +} + func (t *txnContext) apply(cmd redcon.Command) (redisResult, error) { switch strings.ToUpper(string(cmd.Args[0])) { case "SET": @@ -422,18 +516,36 @@ func (t *txnContext) apply(cmd redcon.Command) (redisResult, error) { } func (t *txnContext) applySet(cmd redcon.Command) (redisResult, error) { + if isList, err := t.server.isListKey(context.Background(), cmd.Args[1]); err != nil { + return redisResult{}, err + } else if isList { + return redisResult{typ: resultError, err: errors.New("WRONGTYPE Operation against a key holding the wrong kind of value")}, nil + } + tv, err := t.load(cmd.Args[1]) if err != nil { return redisResult{}, err } tv.raw = cmd.Args[2] - tv.isList = false tv.deleted = false tv.dirty = true return redisResult{typ: resultString, str: "OK"}, nil } func (t *txnContext) applyDel(cmd redcon.Command) (redisResult, error) { + // handle list delete separately + if isList, err := t.server.isListKey(context.Background(), cmd.Args[1]); err != nil { + return redisResult{}, err + } else if isList { + st, err := t.loadListState(cmd.Args[1]) + if err != nil { + return redisResult{}, err + } + st.deleted = true + st.appends = nil + return redisResult{typ: resultInt, integer: 1}, nil + } + tv, err := t.load(cmd.Args[1]) if err != nil { return redisResult{}, err @@ -444,6 +556,12 @@ func (t *txnContext) applyDel(cmd redcon.Command) (redisResult, error) { } func (t *txnContext) applyGet(cmd redcon.Command) (redisResult, error) { + if isList, err := t.server.isListKey(context.Background(), cmd.Args[1]); err != nil { + return redisResult{}, err + } else if isList { + return redisResult{typ: resultError, err: errors.New("WRONGTYPE Operation against a key holding the wrong kind of value")}, nil + } + tv, err := t.load(cmd.Args[1]) if err != nil { return redisResult{}, err @@ -455,6 +573,12 @@ func (t *txnContext) applyGet(cmd redcon.Command) (redisResult, error) { } func (t *txnContext) applyExists(cmd redcon.Command) (redisResult, error) { + if isList, err := t.server.isListKey(context.Background(), cmd.Args[1]); err != nil { + return redisResult{}, err + } else if isList { + return redisResult{typ: resultInt, integer: 1}, nil + } + tv, err := t.load(cmd.Args[1]) if err != nil { return redisResult{}, err @@ -465,63 +589,101 @@ func (t *txnContext) applyExists(cmd redcon.Command) (redisResult, error) { return redisResult{typ: resultInt, integer: 1}, nil } -func (t *txnContext) ensureList(tv *txnValue) error { - if tv.isList { - return nil - } - list, err := decodeList(tv.raw) - if err != nil { - return err - } - tv.list = list - tv.isList = true - return nil -} - func (t *txnContext) applyRPush(cmd redcon.Command) (redisResult, error) { - tv, err := t.load(cmd.Args[1]) + st, err := t.loadListState(cmd.Args[1]) if err != nil { return redisResult{}, err } - if err := t.ensureList(tv); err != nil { - return redisResult{}, err - } + for _, v := range cmd.Args[2:] { - tv.list = append(tv.list, string(v)) + st.appends = append(st.appends, bytes.Clone(v)) } - tv.dirty = true - tv.deleted = false - return redisResult{typ: resultInt, integer: int64(len(tv.list))}, nil + + return redisResult{typ: resultInt, integer: t.listLength(st)}, nil } func (t *txnContext) applyLRange(cmd redcon.Command) (redisResult, error) { - tv, err := t.load(cmd.Args[1]) + st, err := t.loadListState(cmd.Args[1]) if err != nil { return redisResult{}, err } - if err := t.ensureList(tv); err != nil { + + s, e, err := parseRangeBounds(cmd.Args[2], cmd.Args[3], int(t.listLength(st))) + if err != nil { return redisResult{}, err } - start, err := strconv.Atoi(string(cmd.Args[2])) + if e < s { + return redisResult{typ: resultArray, arr: []string{}}, nil + } + + out, err := t.listRangeValues(cmd.Args[1], st, s, e) if err != nil { - return redisResult{}, errors.WithStack(err) + return redisResult{}, err } - end, err := strconv.Atoi(string(cmd.Args[3])) + + return redisResult{typ: resultArray, arr: out}, nil +} + +func parseRangeBounds(startRaw, endRaw []byte, total int) (int, int, error) { + start, err := parseInt(startRaw) if err != nil { - return redisResult{}, errors.WithStack(err) + return 0, 0, err } - s, e := clampRange(start, end, len(tv.list)) - if e < s { - return redisResult{typ: resultArray, arr: []string{}}, nil + end, err := parseInt(endRaw) + if err != nil { + return 0, 0, err } - return redisResult{typ: resultArray, arr: tv.list[s : e+1]}, nil + s, e := clampRange(start, end, total) + return s, e, nil +} + +func (t *txnContext) listRangeValues(key []byte, st *listTxnState, s, e int) ([]string, error) { + persistedLen := int(st.meta.Len) + + switch { + case e < persistedLen: + return t.server.fetchListRange(context.Background(), key, st.meta, int64(s), int64(e)) + case s >= persistedLen: + return appendValues(st.appends, s-persistedLen, e-persistedLen), nil + default: + head, err := t.server.fetchListRange(context.Background(), key, st.meta, int64(s), int64(persistedLen-1)) + if err != nil { + return nil, err + } + tail := appendValues(st.appends, 0, e-persistedLen) + return append(head, tail...), nil + } +} + +func appendValues(buf [][]byte, start, end int) []string { + out := make([]string, 0, end-start+1) + for i := start; i <= end; i++ { + out = append(out, string(buf[i])) + } + return out } func (t *txnContext) commit() error { - if len(t.working) == 0 { + elems := t.buildKeyElems() + + listElems, err := t.buildListElems() + if err != nil { + return err + } + + elems = append(elems, listElems...) + if len(elems) == 0 { return nil } + group := &kv.OperationGroup[kv.OP]{IsTxn: true, Elems: elems} + if _, err := t.server.coordinator.Dispatch(group); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (t *txnContext) buildKeyElems() []*kv.Elem[kv.OP] { keys := make([]string, 0, len(t.working)) for k := range t.working { keys = append(keys, k) @@ -539,34 +701,60 @@ func (t *txnContext) commit() error { elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) continue } - var val []byte - if tv.isList { - enc, err := encodeList(tv.list) - if err != nil { - return err - } - val = enc - } else { - val = tv.raw - } - elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: key, Value: val}) + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: key, Value: tv.raw}) } + return elems +} - if len(elems) == 0 { - return nil +func (t *txnContext) buildListElems() ([]*kv.Elem[kv.OP], error) { + listKeys := make([]string, 0, len(t.listStates)) + for k := range t.listStates { + listKeys = append(listKeys, k) } + sort.Strings(listKeys) - group := &kv.OperationGroup[kv.OP]{IsTxn: true, Elems: elems} - if _, err := t.server.coordinator.Dispatch(group); err != nil { - return errors.WithStack(err) + var elems []*kv.Elem[kv.OP] + for _, k := range listKeys { + st := t.listStates[k] + userKey := []byte(k) + + if st.deleted { + // delete all persisted list items + for seq := st.meta.Head; seq < st.meta.Tail; seq++ { + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: listItemKey(userKey, seq)}) + } + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: listMetaKey(userKey)}) + continue + } + if len(st.appends) == 0 { + continue + } + + startSeq := st.meta.Head + st.meta.Len + for i, v := range st.appends { + elems = append(elems, &kv.Elem[kv.OP]{ + Op: kv.Put, + Key: listItemKey(userKey, startSeq+int64(i)), + Value: v, + }) + } + + st.meta.Len += int64(len(st.appends)) + st.meta.Tail = st.meta.Head + st.meta.Len + metaBytes, err := store.MarshalListMeta(st.meta) + if err != nil { + return nil, errors.WithStack(err) + } + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: listMetaKey(userKey), Value: metaBytes}) } - return nil + return elems, nil } func (r *RedisServer) runTransaction(queue []redcon.Command) ([]redisResult, error) { ctx := &txnContext{ - server: r, - working: map[string]*txnValue{}, + server: r, + working: map[string]*txnValue{}, + listStates: map[string]*listTxnState{}, } results := make([]redisResult, 0, len(queue)) @@ -610,21 +798,14 @@ func (r *RedisServer) writeResults(conn redcon.Conn, results []redisResult) { } } -// list helpers -func decodeList(b []byte) ([]string, error) { - if b == nil { - return []string{}, nil - } - var out []string - if err := json.Unmarshal(b, &out); err != nil { - return nil, errors.WithStack(err) - } - return out, nil +// --- list helpers ---------------------------------------------------- + +func listMetaKey(userKey []byte) []byte { + return store.ListMetaKey(userKey) } -func encodeList(list []string) ([]byte, error) { - b, err := json.Marshal(list) - return b, errors.WithStack(err) +func listItemKey(userKey []byte, seq int64) []byte { + return store.ListItemKey(userKey, seq) } func clampRange(start, end, length int) (int, int) { @@ -646,15 +827,130 @@ func clampRange(start, end, length int) (int, int) { return start, end } -func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, error) { - val, err := r.getValue(key) - if err != nil && !errors.Is(err, store.ErrKeyNotFound) { - return nil, errors.WithStack(err) +func (r *RedisServer) loadListMeta(ctx context.Context, key []byte) (store.ListMeta, bool, error) { + meta, exists, err := r.listStore.LoadMeta(ctx, key) + return meta, exists, errors.WithStack(err) +} + +func (r *RedisServer) isListKey(ctx context.Context, key []byte) (bool, error) { + isList, err := r.listStore.IsList(ctx, key) + return isList, errors.WithStack(err) +} + +func (r *RedisServer) buildRPushOps(meta store.ListMeta, key []byte, values [][]byte) ([]*kv.Elem[kv.OP], store.ListMeta, error) { + if len(values) == 0 { + return nil, meta, nil } - list, err := decodeList(val) + + elems := make([]*kv.Elem[kv.OP], 0, len(values)+1) + seq := meta.Head + meta.Len + for _, v := range values { + vCopy := bytes.Clone(v) + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: listItemKey(key, seq), Value: vCopy}) + seq++ + } + + meta.Len += int64(len(values)) + meta.Tail = meta.Head + meta.Len + + b, err := store.MarshalListMeta(meta) + if err != nil { + return nil, meta, errors.WithStack(err) + } + + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: listMetaKey(key), Value: b}) + return elems, meta, nil +} + +func (r *RedisServer) listRPush(ctx context.Context, key []byte, values [][]byte) (int64, error) { + meta, _, err := r.loadListMeta(ctx, key) + if err != nil { + return 0, err + } + + ops, newMeta, err := r.buildRPushOps(meta, key, values) + if err != nil { + return 0, err + } + if len(ops) == 0 { + return newMeta.Len, nil + } + + group := &kv.OperationGroup[kv.OP]{IsTxn: true, Elems: ops} + if _, err := r.coordinator.Dispatch(group); err != nil { + return 0, errors.WithStack(err) + } + return newMeta.Len, nil +} + +func (r *RedisServer) deleteList(ctx context.Context, key []byte) error { + meta, exists, err := r.loadListMeta(ctx, key) + if err != nil { + return err + } + if !exists { + return nil + } + + start := listItemKey(key, math.MinInt64) + end := listItemKey(key, math.MaxInt64) + + kvs, err := r.store.Scan(ctx, start, end, math.MaxInt) + if err != nil { + return errors.WithStack(err) + } + + ops := make([]*kv.Elem[kv.OP], 0, len(kvs)+1) + for _, kvp := range kvs { + ops = append(ops, &kv.Elem[kv.OP]{Op: kv.Del, Key: kvp.Key}) + } + // delete meta last + ops = append(ops, &kv.Elem[kv.OP]{Op: kv.Del, Key: listMetaKey(key)}) + + // ensure meta bounds consistent even if scan missed (in case of empty list) + _ = meta + + group := &kv.OperationGroup[kv.OP]{IsTxn: true, Elems: ops} + _, err = r.coordinator.Dispatch(group) + return errors.WithStack(err) +} + +func (r *RedisServer) fetchListRange(ctx context.Context, key []byte, meta store.ListMeta, startIdx, endIdx int64) ([]string, error) { + if endIdx < startIdx { + return []string{}, nil + } + + startSeq := meta.Head + startIdx + endSeq := meta.Head + endIdx + + startKey := listItemKey(key, startSeq) + endKey := listItemKey(key, endSeq+1) // exclusive + + kvs, err := r.store.Scan(ctx, startKey, endKey, int(endIdx-startIdx+1)) if err != nil { return nil, errors.WithStack(err) } + + out := make([]string, 0, len(kvs)) + for _, kvp := range kvs { + out = append(out, string(kvp.Value)) + } + return out, nil +} + +func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, error) { + if !r.coordinator.IsLeader() { + return r.proxyLRange(key, startRaw, endRaw) + } + + meta, exists, err := r.loadListMeta(context.Background(), key) + if err != nil { + return nil, err + } + if !exists || meta.Len == 0 { + return []string{}, nil + } + start, err := strconv.Atoi(string(startRaw)) if err != nil { return nil, errors.WithStack(err) @@ -663,11 +959,66 @@ func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, if err != nil { return nil, errors.WithStack(err) } - s, e := clampRange(start, end, len(list)) + + s, e := clampRange(start, end, int(meta.Len)) if e < s { return []string{}, nil } - return list[s : e+1], nil + + return r.fetchListRange(context.Background(), key, meta, int64(s), int64(e)) +} + +func (r *RedisServer) proxyLRange(key []byte, startRaw, endRaw []byte) ([]string, error) { + leader := r.coordinator.RaftLeader() + if leader == "" { + return nil, ErrLeaderNotFound + } + leaderAddr, ok := r.leaderRedis[leader] + if !ok || leaderAddr == "" { + return nil, errors.WithStack(errors.Newf("leader redis address unknown for %s", leader)) + } + + cli := redis.NewClient(&redis.Options{Addr: leaderAddr}) + defer func() { _ = cli.Close() }() + + start, err := parseInt(startRaw) + if err != nil { + return nil, err + } + end, err := parseInt(endRaw) + if err != nil { + return nil, err + } + + res, err := cli.LRange(context.Background(), string(key), int64(start), int64(end)).Result() + return res, errors.WithStack(err) +} + +func (r *RedisServer) proxyRPush(key []byte, values [][]byte) (int64, error) { + leader := r.coordinator.RaftLeader() + if leader == "" { + return 0, ErrLeaderNotFound + } + leaderAddr, ok := r.leaderRedis[leader] + if !ok || leaderAddr == "" { + return 0, errors.WithStack(errors.Newf("leader redis address unknown for %s", leader)) + } + + cli := redis.NewClient(&redis.Options{Addr: leaderAddr}) + defer func() { _ = cli.Close() }() + + args := make([]interface{}, 0, len(values)) + for _, v := range values { + args = append(args, string(v)) + } + + res, err := cli.RPush(context.Background(), string(key), args...).Result() + return res, errors.WithStack(err) +} + +func parseInt(b []byte) (int, error) { + i, err := strconv.Atoi(string(b)) + return i, errors.WithStack(err) } // tryLeaderGet proxies a GET to the current Raft leader, returning the value and @@ -705,25 +1056,21 @@ func (r *RedisServer) getValue(key []byte) ([]byte, error) { } func (r *RedisServer) rpush(conn redcon.Conn, cmd redcon.Command) { - results, err := r.runTransaction([]redcon.Command{cmd}) + ctx := context.Background() + + var length int64 + var err error + if r.coordinator.IsLeader() { + length, err = r.listRPush(ctx, cmd.Args[1], cmd.Args[2:]) + } else { + length, err = r.proxyRPush(cmd.Args[1], cmd.Args[2:]) + } + if err != nil { conn.WriteError(err.Error()) return } - if len(results) != 1 { - conn.WriteError("ERR internal error: rpush should have one result") - return - } - res := results[0] - if res.err != nil { - conn.WriteError(res.err.Error()) - return - } - if res.typ != resultInt { - conn.WriteError("ERR internal error: rpush result should be an integer") - return - } - conn.WriteInt64(res.integer) + conn.WriteInt64(length) } func (r *RedisServer) lrange(conn redcon.Conn, cmd redcon.Command) { diff --git a/adapter/test_util.go b/adapter/test_util.go index fd38dad..5d023a4 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/raft" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -135,8 +136,7 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { ctx := context.Background() ports := assignPorts(n) - cfg := buildRaftConfig(n, ports) - nodes, grpcAdders, redisAdders := setupNodes(t, ctx, n, ports, cfg) + nodes, grpcAdders, redisAdders, cfg := setupNodes(t, ctx, n, ports) waitForNodeListeners(t, ctx, nodes, waitTimeout, waitInterval) waitForConfigReplication(t, cfg, nodes, waitTimeout, waitInterval) @@ -145,6 +145,47 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { return nodes, grpcAdders, redisAdders } +type listeners struct { + grpc net.Listener + redis net.Listener + dynamo net.Listener +} + +func bindListeners(ctx context.Context, lc *net.ListenConfig, port portsAdress) (portsAdress, listeners, bool, error) { + grpcSock, err := lc.Listen(ctx, "tcp", port.grpcAddress) + if err != nil { + if errors.Is(err, unix.EADDRINUSE) { + return port, listeners{}, true, nil + } + return port, listeners{}, false, errors.WithStack(err) + } + + redisSock, err := lc.Listen(ctx, "tcp", port.redisAddress) + if err != nil { + _ = grpcSock.Close() + if errors.Is(err, unix.EADDRINUSE) { + return port, listeners{}, true, nil + } + return port, listeners{}, false, errors.WithStack(err) + } + + dynamoSock, err := lc.Listen(ctx, "tcp", port.dynamoAddress) + if err != nil { + _ = grpcSock.Close() + _ = redisSock.Close() + if errors.Is(err, unix.EADDRINUSE) { + return port, listeners{}, true, nil + } + return port, listeners{}, false, errors.WithStack(err) + } + + return port, listeners{ + grpc: grpcSock, + redis: redisSock, + dynamo: dynamoSock, + }, false, nil +} + func waitForNodeListeners(t *testing.T, ctx context.Context, nodes []Node, waitTimeout, waitInterval time.Duration) { t.Helper() d := &net.Dialer{Timeout: time.Second} @@ -251,24 +292,48 @@ func buildRaftConfig(n int, ports []portsAdress) raft.Configuration { const leaderElectionTimeout = 0 * time.Second -func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, cfg raft.Configuration) ([]Node, []string, []string) { +func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ([]Node, []string, []string, raft.Configuration) { t.Helper() var grpcAdders []string var redisAdders []string var nodes []Node - var lc net.ListenConfig - - leaderRedis := make(map[raft.ServerAddress]string, n) + lc := net.ListenConfig{} + lis := make([]listeners, n) for i := 0; i < n; i++ { - leaderRedis[raft.ServerAddress(ports[i].raftAddress)] = ports[i].redisAddress + var ( + bound portsAdress + l listeners + retry bool + err error + ) + for { + bound, l, retry, err = bindListeners(ctx, &lc, ports[i]) + require.NoError(t, err) + if retry { + ports[i] = portAssigner() + continue + } + ports[i] = bound + lis[i] = l + break + } } + cfg := buildRaftConfig(n, ports) + for i := 0; i < n; i++ { st := store.NewRbMemoryStore() trxSt := store.NewMemoryStoreDefaultTTL() fsm := kv.NewKvFSM(st, trxSt) port := ports[i] + grpcSock := lis[i].grpc + redisSock := lis[i].redis + dynamoSock := lis[i].dynamo + + leaderRedis := map[raft.ServerAddress]string{ + raft.ServerAddress(ports[i].raftAddress): ports[i].redisAddress, + } // リーダーが先に投票を開始させる electionTimeout := leaderElectionTimeout @@ -291,25 +356,18 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c leaderhealth.Setup(r, s, []string{"Example"}) raftadmin.Register(s, r) - grpcSock, err := lc.Listen(ctx, "tcp", port.grpcAddress) - require.NoError(t, err) - grpcAdders = append(grpcAdders, port.grpcAddress) redisAdders = append(redisAdders, port.redisAddress) go func(srv *grpc.Server, lis net.Listener) { assert.NoError(t, srv.Serve(lis)) }(s, grpcSock) - l, err := lc.Listen(ctx, "tcp", port.redisAddress) - require.NoError(t, err) - rd := NewRedisServer(l, st, coordinator, leaderRedis) + rd := NewRedisServer(redisSock, st, coordinator, leaderRedis) go func(server *RedisServer) { assert.NoError(t, server.Run()) }(rd) - dl, err := lc.Listen(ctx, "tcp", port.dynamoAddress) - assert.NoError(t, err) - ds := NewDynamoDBServer(dl, st, coordinator) + ds := NewDynamoDBServer(dynamoSock, st, coordinator) go func() { assert.NoError(t, ds.Run()) }() @@ -327,7 +385,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress, c )) } - return nodes, grpcAdders, redisAdders + return nodes, grpcAdders, redisAdders, cfg } func newRaft(myID string, myAddress string, fsm raft.FSM, bootstrap bool, cfg raft.Configuration, electionTimeout time.Duration) (*raft.Raft, *transport.Manager, error) { diff --git a/kv/coordinator.go b/kv/coordinator.go index 95b9265..0031295 100644 --- a/kv/coordinator.go +++ b/kv/coordinator.go @@ -121,6 +121,8 @@ func (c *Coordinate) toRawRequest(req *Elem[OP]) *pb.Request { panic("unreachable") } +const defaultTxnLockTTLSeconds = uint64(30) + func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { switch req.Op { case Put: @@ -128,6 +130,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_PREPARE, + Ts: defaultTxnLockTTLSeconds, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -138,6 +141,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_COMMIT, + Ts: defaultTxnLockTTLSeconds, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -152,6 +156,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_PREPARE, + Ts: defaultTxnLockTTLSeconds, Mutations: []*pb.Mutation{ { Key: req.Key, @@ -161,6 +166,7 @@ func (c *Coordinate) toTxnRequests(req *Elem[OP]) []*pb.Request { { IsTxn: true, Phase: pb.Phase_COMMIT, + Ts: defaultTxnLockTTLSeconds, Mutations: []*pb.Mutation{ { Key: req.Key, diff --git a/kv/fsm.go b/kv/fsm.go index 485e201..63d34ff 100644 --- a/kv/fsm.go +++ b/kv/fsm.go @@ -209,6 +209,7 @@ func (f *kvFSM) handleCommitRequest(ctx context.Context, r *pb.Request) error { } if !ok { + // Lock already gone: treat as conflict and abort. return errors.WithStack(ErrKeyNotLocked) } diff --git a/kv/shard_router.go b/kv/shard_router.go index 556860d..c3b9bda 100644 --- a/kv/shard_router.go +++ b/kv/shard_router.go @@ -58,7 +58,8 @@ func (s *ShardRouter) process(reqs []*pb.Request, fn func(*routerGroup, []*pb.Re return nil, errors.WithStack(err) } - var max uint64 + var firstErr error + var maxIndex uint64 for gid, rs := range grouped { g, ok := s.getGroup(gid) if !ok { @@ -66,13 +67,19 @@ func (s *ShardRouter) process(reqs []*pb.Request, fn func(*routerGroup, []*pb.Re } r, err := fn(g, rs) if err != nil { - return nil, errors.WithStack(err) + if firstErr == nil { + firstErr = errors.WithStack(err) + } + continue } - if r.CommitIndex > max { - max = r.CommitIndex + if r.CommitIndex > maxIndex { + maxIndex = r.CommitIndex } } - return &TransactionResponse{CommitIndex: max}, nil + if firstErr != nil { + return nil, firstErr + } + return &TransactionResponse{CommitIndex: maxIndex}, nil } func (s *ShardRouter) getGroup(id uint64) (*routerGroup, bool) { diff --git a/store/bolt_store.go b/store/bolt_store.go index fa1b451..d4719a2 100644 --- a/store/bolt_store.go +++ b/store/bolt_store.go @@ -175,6 +175,22 @@ func (t *boltStoreTxn) Exists(_ context.Context, key []byte) (bool, error) { return t.bucket.Get(key) != nil, nil } +func (t *boltStoreTxn) Scan(_ context.Context, start []byte, end []byte, limit int) ([]*KVPair, error) { + if limit <= 0 { + return nil, nil + } + + var res []*KVPair + c := t.bucket.Cursor() + for k, v := c.Seek(start); k != nil && (end == nil || bytes.Compare(k, end) < 0); k, v = c.Next() { + res = append(res, &KVPair{Key: k, Value: v}) + if len(res) >= limit { + break + } + } + return res, nil +} + func (s *boltStore) Txn(ctx context.Context, fn func(ctx context.Context, txn Txn) error) error { btxn, err := s.bbolt.Begin(true) if err != nil { diff --git a/store/list_store.go b/store/list_store.go new file mode 100644 index 0000000..aa21b61 --- /dev/null +++ b/store/list_store.go @@ -0,0 +1,347 @@ +package store + +import ( + "bytes" + "context" + "encoding/binary" + "math" + + "github.com/cockroachdb/errors" +) + +// Wide-column style list storage using per-element keys. +// Item keys: !lst|itm| +// Meta key : !lst|meta| -> [Head(8)][Tail(8)][Len(8)] + +const ( + ListMetaPrefix = "!lst|meta|" + ListItemPrefix = "!lst|itm|" + // limit per scan when deleting to avoid OOM. + deleteBatchSize = 1024 + listMetaBinarySize = 24 + scanAdvanceByte = byte(0x00) +) + +type ListMeta struct { + Head int64 `json:"h"` + Tail int64 `json:"t"` + Len int64 `json:"l"` +} + +// ListStore requires ScanStore to fetch ranges efficiently. +type ListStore struct { + store ScanStore +} + +func NewListStore(base ScanStore) *ListStore { + return &ListStore{store: base} +} + +// IsList reports whether the given key has list metadata. +func (s *ListStore) IsList(ctx context.Context, key []byte) (bool, error) { + _, exists, err := s.LoadMeta(ctx, key) + return exists, err +} + +// PutList replaces the entire list. +func (s *ListStore) PutList(ctx context.Context, key []byte, list []string) error { + meta := ListMeta{Head: 0, Tail: int64(len(list)), Len: int64(len(list))} + metaBytes, err := marshalListMeta(meta) + if err != nil { + return errors.WithStack(err) + } + + return errors.WithStack(s.store.Txn(ctx, func(ctx context.Context, txn Txn) error { + scanTxn, ok := txn.(ScanTxn) + if !ok { + return errors.WithStack(ErrNotSupported) + } + + if err := s.deleteListTxn(ctx, scanTxn, key); err != nil { + return err + } + + for i, v := range list { + if err := scanTxn.Put(ctx, ListItemKey(key, int64(i)), []byte(v)); err != nil { + return errors.WithStack(err) + } + } + if err := scanTxn.Put(ctx, ListMetaKey(key), metaBytes); err != nil { + return errors.WithStack(err) + } + return nil + })) +} + +// GetList returns the whole list. It reconstructs via Scan; avoid for huge lists. +func (s *ListStore) GetList(ctx context.Context, key []byte) ([]string, error) { + meta, exists, err := s.LoadMeta(ctx, key) + if err != nil { + return nil, err + } + if !exists || meta.Len == 0 { + return nil, ErrKeyNotFound + } + return s.Range(ctx, key, 0, int(meta.Len)-1) +} + +// RPush appends values and returns new length. +func (s *ListStore) RPush(ctx context.Context, key []byte, values ...string) (int, error) { + if len(values) == 0 { + return 0, nil + } + + newLen := 0 + + err := s.store.Txn(ctx, func(ctx context.Context, txn Txn) error { + // load meta inside txn for correctness + meta, exists, err := s.loadMetaTxn(ctx, txn, key) + if err != nil { + return err + } + if !exists { + meta = ListMeta{Head: 0, Tail: 0, Len: 0} + } + + startSeq := meta.Head + meta.Len + + for i, v := range values { + seq := startSeq + int64(i) + if err := txn.Put(ctx, ListItemKey(key, seq), []byte(v)); err != nil { + return errors.WithStack(err) + } + } + meta.Len += int64(len(values)) + meta.Tail = meta.Head + meta.Len + metaBytes, err := marshalListMeta(meta) + if err != nil { + return errors.WithStack(err) + } + newLen = int(meta.Len) + return errors.WithStack(txn.Put(ctx, ListMetaKey(key), metaBytes)) + }) + if err != nil { + return 0, errors.WithStack(err) + } + + return newLen, nil +} + +// Range returns elements between start and end (inclusive). +// Negative indexes follow Redis semantics. +func (s *ListStore) Range(ctx context.Context, key []byte, start, end int) ([]string, error) { + meta, exists, err := s.LoadMeta(ctx, key) + if err != nil { + return nil, err + } + if !exists || meta.Len == 0 { + return nil, ErrKeyNotFound + } + + si, ei := clampRange(start, end, int(meta.Len)) + if ei < si { + return []string{}, nil + } + + startSeq := meta.Head + int64(si) + endSeq := meta.Head + int64(ei) + startKey := ListItemKey(key, startSeq) + endKey := ListItemKey(key, endSeq+1) // exclusive + + kvs, err := s.store.Scan(ctx, startKey, endKey, ei-si+1) + if err != nil { + return nil, errors.WithStack(err) + } + + out := make([]string, 0, len(kvs)) + for _, kvp := range kvs { + out = append(out, string(kvp.Value)) + } + return out, nil +} + +// --- helpers --- + +// LoadMeta returns metadata and whether the list exists. +func (s *ListStore) LoadMeta(ctx context.Context, key []byte) (ListMeta, bool, error) { + val, err := s.store.Get(ctx, ListMetaKey(key)) + if err != nil { + if errors.Is(err, ErrKeyNotFound) { + return ListMeta{}, false, nil + } + return ListMeta{}, false, errors.WithStack(err) + } + if len(val) == 0 { + return ListMeta{}, false, nil + } + meta, err := unmarshalListMeta(val) + return meta, err == nil, errors.WithStack(err) +} + +func (s *ListStore) loadMetaTxn(ctx context.Context, txn Txn, key []byte) (ListMeta, bool, error) { + val, err := txn.Get(ctx, ListMetaKey(key)) + if err != nil { + if errors.Is(err, ErrKeyNotFound) { + return ListMeta{}, false, nil + } + return ListMeta{}, false, errors.WithStack(err) + } + if len(val) == 0 { + return ListMeta{}, false, nil + } + meta, err := unmarshalListMeta(val) + return meta, err == nil, errors.WithStack(err) +} + +// deleteListTxn deletes list items and metadata within the provided transaction. +func (s *ListStore) deleteListTxn(ctx context.Context, txn ScanTxn, key []byte) error { + start := ListItemKey(key, mathMinInt64) // inclusive + end := ListItemKey(key, mathMaxInt64) // inclusive sentinel + + for { + kvs, err := txn.Scan(ctx, start, end, deleteBatchSize) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + return errors.WithStack(err) + } + if len(kvs) == 0 { + break + } + + for _, kvp := range kvs { + if err := txn.Delete(ctx, kvp.Key); err != nil { + return errors.WithStack(err) + } + } + + // advance start just after the last processed key to guarantee forward progress + lastKey := kvs[len(kvs)-1].Key + start = append(bytes.Clone(lastKey), scanAdvanceByte) + + if len(kvs) < deleteBatchSize { + break + } + } + + // delete meta last (ignore missing) + if err := txn.Delete(ctx, ListMetaKey(key)); err != nil && !errors.Is(err, ErrKeyNotFound) { + return errors.WithStack(err) + } + return nil +} + +// ListMetaKey builds the metadata key for a user key. +func ListMetaKey(userKey []byte) []byte { + return append([]byte(ListMetaPrefix), userKey...) +} + +// ListItemKey builds the item key for a user key and sequence number. +func ListItemKey(userKey []byte, seq int64) []byte { + // Offset sign bit (seq ^ minInt64) to preserve order, then big-endian encode (8 bytes). + var raw [8]byte + encodeSortableInt64(raw[:], seq) + + buf := make([]byte, 0, len(ListItemPrefix)+len(userKey)+len(raw)) + buf = append(buf, ListItemPrefix...) + buf = append(buf, userKey...) + buf = append(buf, raw[:]...) + return buf +} + +// MarshalListMeta encodes ListMeta into a fixed 24-byte binary format. +func MarshalListMeta(meta ListMeta) ([]byte, error) { return marshalListMeta(meta) } + +// UnmarshalListMeta decodes ListMeta from the fixed 24-byte binary format. +func UnmarshalListMeta(b []byte) (ListMeta, error) { return unmarshalListMeta(b) } + +func marshalListMeta(meta ListMeta) ([]byte, error) { + if meta.Head < 0 || meta.Tail < 0 || meta.Len < 0 { + return nil, errors.WithStack(errors.Newf("list meta contains negative value: head=%d tail=%d len=%d", meta.Head, meta.Tail, meta.Len)) + } + + buf := make([]byte, listMetaBinarySize) + binary.BigEndian.PutUint64(buf[0:8], uint64(meta.Head)) + binary.BigEndian.PutUint64(buf[8:16], uint64(meta.Tail)) + binary.BigEndian.PutUint64(buf[16:24], uint64(meta.Len)) + return buf, nil +} + +func unmarshalListMeta(b []byte) (ListMeta, error) { + if len(b) != listMetaBinarySize { + return ListMeta{}, errors.Wrap(errors.Newf("invalid list meta length: %d", len(b)), "unmarshal list meta") + } + + head := binary.BigEndian.Uint64(b[0:8]) + tail := binary.BigEndian.Uint64(b[8:16]) + length := binary.BigEndian.Uint64(b[16:24]) + + if head > math.MaxInt64 || tail > math.MaxInt64 || length > math.MaxInt64 { + return ListMeta{}, errors.New("list meta value overflows int64") + } + + return ListMeta{ + Head: int64(head), + Tail: int64(tail), + Len: int64(length), + }, nil +} + +// encodeSortableInt64 writes seq with sign bit flipped (seq ^ minInt64) in big-endian order. +const sortableInt64Bytes = 8 + +func encodeSortableInt64(dst []byte, seq int64) { + if len(dst) < sortableInt64Bytes { + return + } + sortable := seq ^ math.MinInt64 + for i := sortableInt64Bytes - 1; i >= 0; i-- { + dst[i] = byte(sortable) + sortable >>= 8 + } +} + +func clampRange(start, end, length int) (int, int) { + if start < 0 { + start = length + start + } + if end < 0 { + end = length + end + } + if start < 0 { + start = 0 + } + if end >= length { + end = length - 1 + } + if end < start { + return 0, -1 + } + return start, end +} + +// sentinel seq for scan bounds +const ( + mathMinInt64 = -1 << 63 + mathMaxInt64 = 1<<63 - 1 +) + +// Exported helpers for other packages (e.g., Redis adapter). +func IsListMetaKey(key []byte) bool { return bytes.HasPrefix(key, []byte(ListMetaPrefix)) } + +func IsListItemKey(key []byte) bool { return bytes.HasPrefix(key, []byte(ListItemPrefix)) } + +// ExtractListUserKey returns the logical user key from a list meta or item key. +// If the key is not a list key, it returns nil. +func ExtractListUserKey(key []byte) []byte { + switch { + case IsListMetaKey(key): + return bytes.TrimPrefix(key, []byte(ListMetaPrefix)) + case IsListItemKey(key): + trimmed := bytes.TrimPrefix(key, []byte(ListItemPrefix)) + if len(trimmed) < sortableInt64Bytes { + return nil + } + return trimmed[:len(trimmed)-sortableInt64Bytes] + default: + return nil + } +} diff --git a/store/list_store_test.go b/store/list_store_test.go new file mode 100644 index 0000000..0900bcb --- /dev/null +++ b/store/list_store_test.go @@ -0,0 +1,152 @@ +package store + +import ( + "context" + "fmt" + "io" + "sync" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" +) + +func TestListStore_PutGet(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ls := NewListStore(NewRbMemoryStore()) + + in := []string{"a", "b", "c"} + assert.NoError(t, ls.PutList(ctx, []byte("k"), in)) + + out, err := ls.GetList(ctx, []byte("k")) + assert.NoError(t, err) + assert.Equal(t, in, out) +} + +func TestListStore_GetList_NotFound(t *testing.T) { + t.Parallel() + + _, err := NewListStore(NewRbMemoryStore()).GetList(context.Background(), []byte("missing")) + assert.ErrorIs(t, err, ErrKeyNotFound) +} + +func TestListStore_RPushAndRange(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ls := NewListStore(NewRbMemoryStore()) + + n, err := ls.RPush(ctx, []byte("numbers"), "zero", "one", "two", "three", "four") + assert.NoError(t, err) + assert.Equal(t, 5, n) + + // Range with positive indexes. + res, err := ls.Range(ctx, []byte("numbers"), 1, 3) + assert.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three"}, res) + + // Range with negative end index. + res, err = ls.Range(ctx, []byte("numbers"), 2, -1) + assert.NoError(t, err) + assert.Equal(t, []string{"two", "three", "four"}, res) +} + +func TestListStore_Range_NotFound(t *testing.T) { + t.Parallel() + + _, err := NewListStore(NewRbMemoryStore()).Range(context.Background(), []byte("nope"), 0, -1) + assert.ErrorIs(t, err, ErrKeyNotFound) +} + +func TestListStore_RPushConcurrent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ls := NewListStore(NewRbMemoryStore()) + + wg := &sync.WaitGroup{} + const n = 50 + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := ls.RPush(ctx, []byte("k"), fmt.Sprintf("v-%d", i)) + assert.NoError(t, err) + }(i) + } + wg.Wait() + + list, err := ls.GetList(ctx, []byte("k")) + assert.NoError(t, err) + assert.Len(t, list, n) +} + +// failingScanStore simulates a transaction commit failure to verify atomicity. +type failingScanStore struct { + inner ScanStore + fail bool +} + +func newFailingScanStore(inner ScanStore, fail bool) *failingScanStore { + return &failingScanStore{inner: inner, fail: fail} +} + +func (s *failingScanStore) Get(ctx context.Context, key []byte) ([]byte, error) { + return s.inner.Get(ctx, key) +} + +func (s *failingScanStore) Put(ctx context.Context, key []byte, value []byte) error { + return s.inner.Put(ctx, key, value) +} + +func (s *failingScanStore) Delete(ctx context.Context, key []byte) error { + return s.inner.Delete(ctx, key) +} + +func (s *failingScanStore) Exists(ctx context.Context, key []byte) (bool, error) { + return s.inner.Exists(ctx, key) +} + +func (s *failingScanStore) Snapshot() (io.ReadWriter, error) { return nil, ErrNotSupported } +func (s *failingScanStore) Restore(io.Reader) error { return ErrNotSupported } +func (s *failingScanStore) Close() error { return nil } + +func (s *failingScanStore) Scan(ctx context.Context, start []byte, end []byte, limit int) ([]*KVPair, error) { + return s.inner.Scan(ctx, start, end, limit) +} + +// Txn executes the function; if fail is set, it aborts commit and returns an error. +func (s *failingScanStore) Txn(ctx context.Context, f func(ctx context.Context, txn Txn) error) error { + err := s.inner.Txn(ctx, func(ctx context.Context, txn Txn) error { + if s.fail { + return errors.New("injected commit failure") + } + return f(ctx, txn) + }) + + return err +} + +func TestListStore_PutList_RollbackOnTxnFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + rawBase := NewRbMemoryStore() + ls := NewListStore(rawBase) + + initial := []string{"a", "b", "c"} + assert.NoError(t, ls.PutList(ctx, []byte("k"), initial)) + + failStore := newFailingScanStore(rawBase, true) + lsFail := NewListStore(failStore) + + err := lsFail.PutList(ctx, []byte("k"), []string{"x", "y"}) + assert.Error(t, err, "expected injected failure") + + // Original list must remain intact because txn never committed. + out, err := ls.GetList(ctx, []byte("k")) + assert.NoError(t, err) + assert.Equal(t, initial, out) +} diff --git a/store/rb_memory_store.go b/store/rb_memory_store.go index c6e967d..1815b6d 100644 --- a/store/rb_memory_store.go +++ b/store/rb_memory_store.go @@ -9,6 +9,7 @@ import ( "io" "log/slog" "os" + "sort" "sync" "time" @@ -384,8 +385,44 @@ func (t *rbMemoryStoreTxn) Scan(_ context.Context, start []byte, end []byte, lim t.mu.RLock() defer t.mu.RUnlock() - var result []*KVPair + if limit <= 0 { + return nil, nil + } + + deleted := t.deletedSet() + staged := t.stagedMap() + included := make(map[string]struct{}) + + result := make([]*KVPair, 0, limit) + t.addBaseResults(&result, included, start, end, limit, staged, deleted) + if len(result) < limit { + t.addStagedOnly(&result, included, start, end, limit, staged) + } + + sort.Slice(result, func(i, j int) bool { + return bytes.Compare(result[i].Key, result[j].Key) < 0 + }) + + if len(result) > limit { + result = result[:limit] + } + + return result, nil +} +// helper methods below assume t.mu is already RLocked by caller. +func (t *rbMemoryStoreTxn) deletedSet() map[string]struct{} { + deleted := make(map[string]struct{}, len(t.ops)) + for _, op := range t.ops { + if op.opType == OpTypeDelete { + deleted[string(op.key)] = struct{}{} + } + } + return deleted +} + +func (t *rbMemoryStoreTxn) stagedMap() map[string][]byte { + staged := make(map[string][]byte, t.tree.Size()) t.tree.Each(func(key interface{}, value interface{}) { k, ok := key.([]byte) if !ok { @@ -395,25 +432,66 @@ func (t *rbMemoryStoreTxn) Scan(_ context.Context, start []byte, end []byte, lim if !ok { return } + staged[string(k)] = v + }) + return staged +} - if bytes.Compare(k, start) < 0 { +func withinBounds(k, start, end []byte) bool { + if start != nil && bytes.Compare(k, start) < 0 { + return false + } + if end != nil && bytes.Compare(k, end) > 0 { + return false + } + return true +} + +func (t *rbMemoryStoreTxn) addBaseResults(result *[]*KVPair, included map[string]struct{}, start, end []byte, limit int, staged map[string][]byte, deleted map[string]struct{}) { + t.s.tree.Each(func(key interface{}, value interface{}) { + if len(*result) >= limit { return } - if bytes.Compare(k, end) > 0 { + k, ok := key.([]byte) + if !ok { + return + } + if !withinBounds(k, start, end) { + return + } + if _, deletedHere := deleted[string(k)]; deletedHere { return } - if len(result) >= limit { + v, ok := value.([]byte) + if !ok { return } + if stagedVal, ok := staged[string(k)]; ok { + v = stagedVal + } - result = append(result, &KVPair{ - Key: k, - Value: v, - }) + *result = append(*result, &KVPair{Key: k, Value: v}) + included[string(k)] = struct{}{} }) - return result, nil +} + +func (t *rbMemoryStoreTxn) addStagedOnly(result *[]*KVPair, included map[string]struct{}, start, end []byte, limit int, staged map[string][]byte) { + for kStr, v := range staged { + if len(*result) >= limit { + return + } + if _, already := included[kStr]; already { + continue + } + kb := []byte(kStr) + if !withinBounds(kb, start, end) { + continue + } + *result = append(*result, &KVPair{Key: kb, Value: v}) + included[kStr] = struct{}{} + } } func (t *rbMemoryStoreTxn) Put(_ context.Context, key []byte, value []byte) error {