diff --git a/pkg/oramExecutor/oramExecutor.go b/pkg/oramExecutor/oramExecutor.go
index 00fcf509c..16377d544 100644
--- a/pkg/oramExecutor/oramExecutor.go
+++ b/pkg/oramExecutor/oramExecutor.go
@@ -7,6 +7,8 @@ import (
 	"fmt"
 	"os"
 	"strings"
+	"sync"
+	"sync/atomic"
 
 	executor "github.com/project/ObliSql/api/oramExecutor"
 	// "github.com/redis/go-redis/v9"
@@ -19,10 +21,35 @@ const (
 	stashSize = 100000 // Maximum number of blocks in stash
 )
 
+// Operation describes a single keyed request. It is currently unused;
+// KVPair is the message format between ExecuteBatch and processBatches.
+type Operation struct {
+	RequestID uint64
+	Key       string
+	Value     string
+	Index     int
+}
+
+// KVPair carries one operation together with the id of the response
+// channel that should receive its result.
+type KVPair struct {
+	channelId string
+	Key       string
+	Value     string
+}
+
+type responseChannel struct {
+	m       *sync.RWMutex
+	channel chan KVPair
+}
+
 type MyOram struct {
 	executor.UnimplementedExecutorServer
-	// rdb *redis.Client
 	o *ORAM
+
+	batchSize int
+
+	channelMap    map[string]responseChannel
+	requestNumber atomic.Int64
+	channelLock   sync.RWMutex
+
+	oramExecutorChannel chan *KVPair
 }
 
 type tempBlock struct {
@@ -34,55 +61,116 @@ type StringPair struct {
 	Second string
 }
 
-func (e MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
-	fmt.Printf("Got a request with ID: %d \n", req.RequestId)
-
-	// set batchsize
-	batchSize := 60
-
-	// Batching(requests []request.Request, batchSize int)
+func (e *MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
+	if len(req.Keys) != len(req.Values) {
+		return nil, fmt.Errorf("keys and values length mismatch")
+	}
 
-	var replyKeys []string
-	var replyVals []string
+	reqNum := e.requestNumber.Add(1) // Unique id for this caller's response channel
 
-	for start := 0; start < len(req.Values); start += batchSize {
+	recvResp := make([]KVPair, 0, len(req.Keys)) // Collects completed key/value pairs
 
-		var requestList []Request
-		var returnValues []string
+	channelId := fmt.Sprintf("%d-%d", req.RequestId, reqNum)
+	localRespChannel := make(chan KVPair, len(req.Keys))
 
-		end := start + batchSize
-		if end > len(req.Values) {
-			end = len(req.Values) // Ensure we don't go out of bounds
+	e.channelLock.Lock() // Register the response channel in the global map
+	e.channelMap[channelId] = responseChannel{
+		m:       &sync.RWMutex{},
+		channel: localRespChannel,
+	}
+	e.channelLock.Unlock()
+
+	for i, key := range req.Keys {
+		kv := &KVPair{
+			channelId: channelId,
+			Key:       key,
+			Value:     req.Values[i],
 		}
+		// Blocks if the executor channel is full
+		e.oramExecutorChannel <- kv
+	}
 
-		// Slice the keys and values for the current batch
-		batchKeys := req.Keys[start:end]
-		batchValues := req.Values[start:end]
+	// All operations submitted; wait for one response per key
+	for i := 0; i < len(req.Keys); i++ {
+		recvResp = append(recvResp, <-localRespChannel)
+	}
 
-		for i := range batchKeys {
-			// Read operation
-			currentRequest := Request{
-				Key:   batchKeys[i],
-				Value: batchValues[i],
-			}
+	close(localRespChannel)
 
-			requestList = append(requestList, currentRequest)
-		}
+	e.channelLock.Lock()
+	delete(e.channelMap, channelId)
+	e.channelLock.Unlock()
 
-		returnValues, _ = e.o.Batching(requestList, batchSize)
+	sendKeys := make([]string, 0, len(req.Keys))
+	sendVal := make([]string, 0, len(req.Keys))
+	for _, v := range recvResp {
+		sendKeys = append(sendKeys, v.Key)
+		sendVal = append(sendVal, v.Value)
+	}
 
-		replyKeys = append(replyKeys, batchKeys...)
-		replyVals = append(replyVals, returnValues...)
-	}
-
+	// Respond under the original request ID
 	return &executor.RespondBatchORAM{
 		RequestId: req.RequestId,
-		Keys:      replyKeys,
-		Values:    replyVals,
+		Keys:      sendKeys,
+		Values:    sendVal,
 	}, nil
 }
 
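+// processBatches is the single consumer of oramExecutorChannel: it drains
+// pending operations into batches of at most batchSize, executes each batch
+// against the ORAM, and fans the results back out to the per-request
+// response channels registered in channelMap.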
+func (e *MyOram) processBatches() {
+	for {
+		// Block until at least one operation arrives, then drain up to
+		// batchSize operations without busy-waiting. Partial batches are
+		// still executed, so small requests are never stranded waiting
+		// for the channel to fill.
+		first := <-e.oramExecutorChannel
+
+		requestList := []Request{{Key: first.Key, Value: first.Value}}
+		chanIds := []string{first.channelId}
+
+	drain:
+		for len(requestList) < e.batchSize {
+			select {
+			case op := <-e.oramExecutorChannel:
+				chanIds = append(chanIds, op.channelId)
+				requestList = append(requestList, Request{
+					Key:   op.Key,
+					Value: op.Value,
+				})
+			default:
+				break drain // Channel is empty; run what we have
+			}
+		}
+
+		// Execute the ORAM batch
+		returnValues, err := e.o.Batching(requestList, e.batchSize)
+		if err != nil {
+			// Log and skip; callers of this batch will keep waiting, so a
+			// production version should propagate the error to them instead.
+			fmt.Printf("ORAM batch error: %v\n", err)
+			continue
+		}
+
+		// Snapshot the response channels under the read lock, then fan
+		// the results back out to each waiting caller.
+		channelCache := make(map[string]chan KVPair, len(chanIds))
+		e.channelLock.RLock()
+		for _, id := range chanIds {
+			channelCache[id] = e.channelMap[id].channel
+		}
+		e.channelLock.RUnlock()
+
+		for i := range requestList {
+			channelCache[chanIds[i]] <- KVPair{
+				Key:   requestList[i].Key,
+				Value: returnValues[i],
+			}
+		}
+	}
+}
+
 func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string, useSnapshot bool, key []byte) (*MyOram, error) {
 	// If key is not provided (nil or empty), generate a random key
 	if len(key) == 0 {
@@ -111,6 +199,7 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
 		// Load the Stashmap and Keymap into memory
 		// Allow redis to update state using dump.rdb
 		oram.loadSnapshotMaps()
+		fmt.Println("ORAM snapshot loaded successfully!")
 	} else {
 		// Clear the Redis database to ensure a fresh start
 		if err := client.FlushDB(); err != nil {
@@ -171,9 +260,17 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
 	}
 
 	myOram := &MyOram{
-		o: oram,
+		o:          oram,
+		batchSize:  60, // Batch size; ideally taken from config
+		channelMap: make(map[string]responseChannel),
+		// Buffered so ExecuteBatch callers rarely block while submitting;
+		// channelLock needs no initializer, its zero value is ready to use.
+		oramExecutorChannel: make(chan *KVPair, 100000),
 	}
 
+	go myOram.processBatches() // Start the single batch-processing goroutine
+
 	return myOram, nil
 }
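
For review context, a minimal usage sketch (illustrative only, not part of the patch). It is written as if it lived in the same package, so the existing imports cover it; the Redis address, the tree parameters (LogCapacity=10, Z=4), the key names, the convention that an empty value denotes a read, and the int64 type of RequestId are all assumptions made for the example.

	func exampleUsage() error {
		// nil key: NewORAM generates a random encryption key.
		oramExec, err := NewORAM(10, 4, 100000, "localhost:6379", "trace.txt", false, nil)
		if err != nil {
			return err
		}

		var wg sync.WaitGroup
		for c := 1; c <= 4; c++ { // Four concurrent callers share one executor
			wg.Add(1)
			go func(id int64) {
				defer wg.Done()
				resp, err := oramExec.ExecuteBatch(context.Background(), &executor.RequestBatchORAM{
					RequestId: id,
					Keys:      []string{"user1", "user2"},
					Values:    []string{"", ""}, // Empty values assumed to denote reads
				})
				if err != nil {
					fmt.Printf("batch %d failed: %v\n", id, err)
					return
				}
				fmt.Printf("batch %d values: %v\n", id, resp.Values)
			}(int64(c))
		}
		wg.Wait()
		return nil
	}

Because ExecuteBatch blocks until every key it submitted has been answered, all four goroutines return only after processBatches has flushed their operations; with the drain-on-empty loop in the patch, that holds even when fewer than batchSize operations are pending.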
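One design note on the drain loop: under light load it will often run batches of size one, each still paying a full ORAM access. If fuller batches matter more than latency, a short linger window is a common alternative; a sketch of the replacement drain loop follows, where the "time" import and the 2ms figure are assumptions to tune per deployment.

	// Variant of the drain loop: wait up to `linger` for the batch to fill
	// before executing it, instead of running as soon as the channel empties.
	linger := 2 * time.Millisecond // Illustrative value
	timer := time.NewTimer(linger)
	drain:
	for len(requestList) < e.batchSize {
		select {
		case op := <-e.oramExecutorChannel:
			chanIds = append(chanIds, op.channelId)
			requestList = append(requestList, Request{Key: op.Key, Value: op.Value})
		case <-timer.C:
			break drain // Window elapsed; run what we have
		}
	}
	timer.Stop()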