From dffa945598e90ec8aa65e503e1cb566c582b0638 Mon Sep 17 00:00:00 2001
From: nachiketsrao
Date: Fri, 21 Feb 2025 19:45:47 -0500
Subject: [PATCH 1/2] oramExecutor with gRPC queuing for batches

---
 pkg/oramExecutor/oramExecutor.go | 149 ++++++++++++++++++++++++-------
 1 file changed, 115 insertions(+), 34 deletions(-)

diff --git a/pkg/oramExecutor/oramExecutor.go b/pkg/oramExecutor/oramExecutor.go
index ca6f9ae92..ff3d25221 100644
--- a/pkg/oramExecutor/oramExecutor.go
+++ b/pkg/oramExecutor/oramExecutor.go
@@ -7,6 +7,8 @@ import (
 	"fmt"
 	"os"
 	"strings"
+	"sync"
+	"sync/atomic"
 
 	executor "github.com/project/ObliSql/api/oramExecutor"
 	// "github.com/redis/go-redis/v9"
@@ -19,10 +21,32 @@ const (
 	stashSize = 100000 // Maximum number of blocks in stash
 )
 
+type Operation struct {
+	RequestID uint64
+	Key       string
+	Value     string
+	Index     int
+}
+
+type ResponseTracker struct {
+	Values    []string
+	Count     int32 // Use atomic operations
+	Completed chan struct{}
+}
+
 type MyOram struct {
 	executor.UnimplementedExecutorServer
-	// rdb *redis.Client
 	o *ORAM
+
+	opMutex sync.Mutex
+	opCond  *sync.Cond
+	opQueue []Operation
+
+	trackerMutex sync.Mutex
+	trackers     map[uint64]*ResponseTracker
+
+	requestID uint64 // atomic
+	batchSize int
 }
 
 type StringPair struct {
@@ -30,53 +54,105 @@ type StringPair struct {
 	First  string
 	Second string
 }
 
-func (e MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
-	fmt.Printf("Got a request with ID: %d \n", req.RequestId)
+func (e *MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORAM) (*executor.RespondBatchORAM, error) {
+	if len(req.Keys) != len(req.Values) {
+		return nil, fmt.Errorf("keys and values length mismatch")
+	}
 
-	// set batchsize
-	batchSize := 60
+	// Generate unique request ID
+	id := atomic.AddUint64(&e.requestID, 1)
+	numOps := len(req.Keys)
 
-	// Batching(requests []request.Request, batchSize int)
+	// Setup response tracker
+	tracker := &ResponseTracker{
+		Values:    make([]string, numOps),
+		Count:     0,
+		Completed: make(chan struct{}),
+	}
 
-	var replyKeys []string
-	var replyVals []string
+	// Register tracker
+	e.trackerMutex.Lock()
+	e.trackers[id] = tracker
+	e.trackerMutex.Unlock()
+
+	// Queue operations
+	e.opMutex.Lock()
+	for i := 0; i < numOps; i++ {
+		e.opQueue = append(e.opQueue, Operation{
+			RequestID: id,
+			Key:       req.Keys[i],
+			Value:     req.Values[i],
+			Index:     i,
+		})
+	}
+	e.opCond.Signal() // Notify batch processor
+	e.opMutex.Unlock()
 
-	for start := 0; start < len(req.Values); start += batchSize {
+	// Wait for completion
+	<-tracker.Completed
 
-		var requestList []Request
-		var returnValues []string
+	// Return response with original request ID
+	return &executor.RespondBatchORAM{
+		RequestId: req.RequestId,
+		Keys:      req.Keys,
+		Values:    tracker.Values,
+	}, nil
+}
 
-		end := start + batchSize
-		if end > len(req.Values) {
-			end = len(req.Values) // Ensure we don't go out of bounds
+func (e *MyOram) processBatches() {
+	for {
+		e.opMutex.Lock()
+		// Wait for operations to process
+		for len(e.opQueue) == 0 {
+			e.opCond.Wait()
		}
 
-		// Slice the keys and values for the current batch
-		batchKeys := req.Keys[start:end]
-		batchValues := req.Values[start:end]
+		// Determine batch size (Wait sits in a loop: a wakeup only
+		// means the predicate must be re-checked)
+		batchSize := e.batchSize
+		for len(e.opQueue) < batchSize {
+			e.opCond.Wait()
+		}
+		batchOps := e.opQueue[:batchSize]
+		e.opQueue = e.opQueue[batchSize:]
+		e.opMutex.Unlock()
 
-		for i := range batchKeys {
-			// Read operation
-			currentRequest := Request{
-				Key:   batchKeys[i],
-				Value: batchValues[i],
-			}
+		// Prepare ORAM batch request
+		var requestList []Request
+		for _, op := range batchOps {
+			requestList = append(requestList, Request{
+				Key:   op.Key,
+				Value: op.Value,
+			})
+		}
 
-			requestList = append(requestList, currentRequest)
+		// Execute ORAM batch
+		returnValues, err := e.o.Batching(requestList, batchSize)
+		if err != nil {
+			// Handle error (e.g., log and continue)
+			fmt.Printf("ORAM batch error: %v\n", err)
+			continue
 		}
 
-		returnValues, _ = e.o.Batching(requestList, batchSize)
+		// Distribute responses to trackers
+		e.trackerMutex.Lock()
+		for i, op := range batchOps {
+			tracker, exists := e.trackers[op.RequestID]
+			if !exists {
+				continue // Tracker already removed
+			}
 
-		replyKeys = append(replyKeys, batchKeys...)
-		replyVals = append(replyVals, returnValues...)
+			// Update value; bump the atomic completion count
+			tracker.Values[op.Index] = returnValues[i]
+			count := atomic.AddInt32(&tracker.Count, 1)
+
+			// Check if all responses received
+			if int(count) == len(tracker.Values) {
+				close(tracker.Completed)
+				delete(e.trackers, op.RequestID)
+			}
+		}
+		e.trackerMutex.Unlock()
 	}
-
-	return &executor.RespondBatchORAM{
-		RequestId: req.RequestId,
-		Keys:      replyKeys,
-		Values:    replyVals,
-	}, nil
 }
 
 func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string, useSnapshot bool, key []byte) (*MyOram, error) {
@@ -167,8 +243,13 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
 	}
 
 	myOram := &MyOram{
-		o: oram,
+		o:         oram,
+		batchSize: 60, // Set from config or constant
+		trackers:  make(map[uint64]*ResponseTracker),
+		opQueue:   make([]Operation, 0),
 	}
+	myOram.opCond = sync.NewCond(&myOram.opMutex)
+	go myOram.processBatches() // Start batch processing
 
 	return myOram, nil
 }
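
Note on the queue/cond design in PATCH 1: processBatches sleeps on a
sync.Cond and is woken by ExecuteBatch after new operations are appended.
The idiom is only safe when cond.Wait sits inside a loop that re-checks the
predicate, since a wakeup does not guarantee a full batch is queued. A
stand-alone sketch of the pattern (illustrative names, not code from the
patch):

    package main

    import (
    	"fmt"
    	"sync"
    )

    type batcher struct {
    	mu    sync.Mutex
    	cond  *sync.Cond
    	queue []string
    	size  int
    }

    func newBatcher(size int) *batcher {
    	b := &batcher{size: size}
    	b.cond = sync.NewCond(&b.mu)
    	return b
    }

    // add appends one item and wakes the consumer.
    func (b *batcher) add(item string) {
    	b.mu.Lock()
    	b.queue = append(b.queue, item)
    	b.cond.Signal()
    	b.mu.Unlock()
    }

    // take blocks until a full batch is queued, then removes and returns it.
    func (b *batcher) take() []string {
    	b.mu.Lock()
    	defer b.mu.Unlock()
    	// Wait must sit in a loop: a wakeup means "re-check",
    	// not "a full batch is ready".
    	for len(b.queue) < b.size {
    		b.cond.Wait()
    	}
    	batch := b.queue[:b.size]
    	b.queue = b.queue[b.size:]
    	return batch
    }

    func main() {
    	b := newBatcher(3)
    	go func() {
    		for _, s := range []string{"a", "b", "c"} {
    			b.add(s)
    		}
    	}()
    	fmt.Println(b.take()) // [a b c]
    }
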
From e2bfaf57129ae44ea45719227a12cac63d42d555 Mon Sep 17 00:00:00 2001
From: nachiketsrao
Date: Tue, 25 Feb 2025 23:57:55 -0500
Subject: [PATCH 2/2] Rework batch queuing to use channels

---
 pkg/oramExecutor/oramExecutor.go | 190 +++++++++++++++++--------------
 1 file changed, 103 insertions(+), 87 deletions(-)

diff --git a/pkg/oramExecutor/oramExecutor.go b/pkg/oramExecutor/oramExecutor.go
index 580998c17..16377d544 100644
--- a/pkg/oramExecutor/oramExecutor.go
+++ b/pkg/oramExecutor/oramExecutor.go
@@ -28,25 +28,28 @@ type Operation struct {
 	Index int
 }
 
-type ResponseTracker struct {
-	Values    []string
-	Count     int32 // Use atomic operations
-	Completed chan struct{}
+type KVPair struct {
+	channelId string
+	Key       string
+	Value     string
+}
+
+type responseChannel struct {
+	m       *sync.RWMutex
+	channel chan KVPair
 }
 
 type MyOram struct {
 	executor.UnimplementedExecutorServer
 	o *ORAM
 
-	opMutex sync.Mutex
-	opCond  *sync.Cond
-	opQueue []Operation
+	batchSize int
 
-	trackerMutex sync.Mutex
-	trackers     map[uint64]*ResponseTracker
+	channelMap    map[string]responseChannel
+	requestNumber atomic.Int64
+	channelLock   sync.RWMutex
 
-	requestID uint64 // atomic
-	batchSize int
+	oramExecutorChannel chan *KVPair
 }
 
 type tempBlock struct {
@@ -63,99 +66,108 @@ func (e *MyOram) ExecuteBatch(ctx context.Context, req *executor.RequestBatchORA
 		return nil, fmt.Errorf("keys and values length mismatch")
 	}
 
-	// Generate unique request ID
-	id := atomic.AddUint64(&e.requestID, 1)
-	numOps := len(req.Keys)
+	reqNum := e.requestNumber.Add(1) // New id for this client/batch channel
+
+	recv_resp := make([]KVPair, 0, len(req.Keys)) // This will store completed key value pairs
+
+	channelId := fmt.Sprintf("%d-%d", req.RequestId, reqNum)
+	localRespChannel := make(chan KVPair, len(req.Keys))
 
-	// Setup response tracker
-	tracker := &ResponseTracker{
-		Values:    make([]string, numOps),
-		Count:     0,
-		Completed: make(chan struct{}),
+	e.channelLock.Lock() // Add channel to global map
+	e.channelMap[channelId] = responseChannel{
+		m:       &sync.RWMutex{},
+		channel: localRespChannel,
 	}
+	e.channelLock.Unlock()
 
-	// Register tracker
-	e.trackerMutex.Lock()
-	e.trackers[id] = tracker
-	e.trackerMutex.Unlock()
-
-	// Queue operations
-	e.opMutex.Lock()
-	for i := 0; i < numOps; i++ {
-		e.opQueue = append(e.opQueue, Operation{
-			RequestID: id,
-			Key:       req.Keys[i],
-			Value:     req.Values[i],
-			Index:     i,
-		})
+	for i, key := range req.Keys {
+		value := req.Values[i]
+		kv := &KVPair{
+			channelId: channelId,
+			Key:       key,
+			Value:     value,
+		}
+		e.oramExecutorChannel <- kv // Blocks if the channel is full
 	}
-	e.opCond.Signal() // Notify batch processor
-	e.opMutex.Unlock()
 
-	// Wait for completion
-	<-tracker.Completed
+	// Finished adding keys to ORAM channel
+
+	// Now wait for responses
+	for i := 0; i < len(req.Keys); i++ {
+		item := <-localRespChannel
+		recv_resp = append(recv_resp, item)
+	}
+
+	close(localRespChannel)
+
+	e.channelLock.Lock()
+	delete(e.channelMap, channelId)
+	e.channelLock.Unlock()
+
+	sendKeys := make([]string, 0, len(req.Keys))
+	sendVal := make([]string, 0, len(req.Keys))
+
+	for _, v := range recv_resp {
+		sendKeys = append(sendKeys, v.Key)
+		sendVal = append(sendVal, v.Value)
+	}
 
 	// Return response with original request ID
 	return &executor.RespondBatchORAM{
 		RequestId: req.RequestId,
-		Keys:      req.Keys,
-		Values:    tracker.Values,
+		Keys:      sendKeys,
+		Values:    sendVal,
 	}, nil
 }
 
 func (e *MyOram) processBatches() {
 	for {
-		e.opMutex.Lock()
-		// Wait for operations to process
-		for len(e.opQueue) == 0 {
-			e.opCond.Wait()
-		}
-
-		// Determine batch size (Wait sits in a loop: a wakeup only
-		// means the predicate must be re-checked)
-		batchSize := e.batchSize
-		for len(e.opQueue) < batchSize {
-			e.opCond.Wait()
-		}
-		batchOps := e.opQueue[:batchSize]
-		e.opQueue = e.opQueue[batchSize:]
-		e.opMutex.Unlock()
-
-		// Prepare ORAM batch request
-		var requestList []Request
-		for _, op := range batchOps {
-			requestList = append(requestList, Request{
-				Key:   op.Key,
-				Value: op.Value,
-			})
-		}
+		if len(e.oramExecutorChannel) >= e.batchSize {
+			var requestList []Request
+			var chanIds []string
 
-		// Execute ORAM batch
-		returnValues, err := e.o.Batching(requestList, batchSize)
-		if err != nil {
-			// Handle error (e.g., log and continue)
-			fmt.Printf("ORAM batch error: %v\n", err)
-			continue
-		}
+			for i := 0; i < e.batchSize; i++ {
+				op := <-e.oramExecutorChannel // Read from channel
+				chanIds = append(chanIds, op.channelId)
+				requestList = append(requestList, Request{
+					Key:   op.Key,
+					Value: op.Value,
+				})
+			}
 
-		// Distribute responses to trackers
-		e.trackerMutex.Lock()
-		for i, op := range batchOps {
-			tracker, exists := e.trackers[op.RequestID]
-			if !exists {
-				continue // Tracker already removed
+			// Execute ORAM batch
+			returnValues, err := e.o.Batching(requestList, e.batchSize)
+			if err != nil {
+				// Handle error (e.g., log and continue)
+				fmt.Printf("ORAM batch error: %v\n", err)
+				continue
 			}
 
-			// Update value; bump the atomic completion count
-			tracker.Values[op.Index] = returnValues[i]
-			count := atomic.AddInt32(&tracker.Count, 1)
+			channelCache := make(map[string]chan KVPair, e.batchSize)
 
-			// Check if all responses received
-			if int(count) == len(tracker.Values) {
-				close(tracker.Completed)
-				delete(e.trackers, op.RequestID)
+			e.channelLock.RLock()
+			for _, v := range chanIds {
+				channelCache[v] = e.channelMap[v].channel
+			}
+			e.channelLock.RUnlock()
+
+			for i := 0; i < e.batchSize; i++ {
+				newKVPair := KVPair{
+					Key:   requestList[i].Key,
+					Value: returnValues[i],
+				}
+				responseChannel := channelCache[chanIds[i]]
+				responseChannel <- newKVPair
 			}
 		}
-		e.trackerMutex.Unlock()
 	}
 }
@@ -187,6 +199,7 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
 		// Load the Stashmap and Keymap into memory
 		// Allow redis to update state using dump.rdb
 		oram.loadSnapshotMaps()
+		fmt.Println("ORAM snapshot loaded successfully!")
 	} else {
 		// Clear the Redis database to ensure a fresh start
 		if err := client.FlushDB(); err != nil {
@@ -247,12 +260,15 @@ func NewORAM(LogCapacity, Z, StashSize int, redisAddr string, tracefile string,
 	}
 
 	myOram := &MyOram{
-		o:         oram,
-		batchSize: 60, // Set from config or constant
-		trackers:  make(map[uint64]*ResponseTracker),
-		opQueue:   make([]Operation, 0),
+		o:           oram,
+		batchSize:   60, // Set from config or constant
+		channelMap:  make(map[string]responseChannel),
+		channelLock: sync.RWMutex{},
+		oramExecutorChannel: make(chan *KVPair, 100000), // Buffered fan-in queue of pending operations
 	}
-	myOram.opCond = sync.NewCond(&myOram.opMutex)
+
 	go myOram.processBatches() // Start batch processing
 
 	return myOram, nil
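
For reference, a client would drive this service roughly as follows. This is
a minimal sketch: the NewExecutorClient constructor is assumed from the gRPC
bindings generated for the Executor service, and the dial target is a
placeholder; neither appears in these patches.

    package main

    import (
    	"context"
    	"fmt"
    	"log"

    	executor "github.com/project/ObliSql/api/oramExecutor"
    	"google.golang.org/grpc"
    	"google.golang.org/grpc/credentials/insecure"
    )

    func main() {
    	conn, err := grpc.Dial("localhost:9090",
    		grpc.WithTransportCredentials(insecure.NewCredentials()))
    	if err != nil {
    		log.Fatalf("dial: %v", err)
    	}
    	defer conn.Close()

    	client := executor.NewExecutorClient(conn)
    	resp, err := client.ExecuteBatch(context.Background(), &executor.RequestBatchORAM{
    		RequestId: 1,
    		Keys:      []string{"k1", "k2"},
    		// How values are interpreted (read vs. write) is up to the
    		// ORAM Batching layer; empty strings here are placeholders.
    		Values: []string{"", "v2"},
    	})
    	if err != nil {
    		log.Fatalf("ExecuteBatch: %v", err)
    	}
    	for i, k := range resp.Keys {
    		fmt.Printf("%s = %s\n", k, resp.Values[i])
    	}
    }

Note that the server keys its response channels by the client-supplied
RequestId combined with an internal counter (channelId), so concurrent
clients may reuse RequestId values without their responses colliding.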
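
One caveat in processBatches as committed: the loop polls
len(e.oramExecutorChannel) and only drains the queue once batchSize
operations are waiting, so it spins while idle, and a final batch smaller
than batchSize is never flushed, leaving those callers blocked. A blocking
collector with a flush timeout avoids both. The sketch below reuses the
fields from PATCH 2; flushInterval is an assumed tuning knob, the time
import is assumed, and it presumes Batching accepts a short batch. If the
ORAM requires fixed-size batches for obliviousness, pad the remainder with
dummy operations instead of shrinking the batch.

    // Sketch: blocking batch collector with a flush timeout.
    func (e *MyOram) processBatchesBlocking(flushInterval time.Duration) {
    	for {
    		// Block until at least one operation arrives (no busy-wait).
    		first := <-e.oramExecutorChannel
    		batch := []*KVPair{first}

    		timer := time.NewTimer(flushInterval)
    	collect:
    		for len(batch) < e.batchSize {
    			select {
    			case op := <-e.oramExecutorChannel:
    				batch = append(batch, op)
    			case <-timer.C:
    				break collect // Flush a partial batch instead of starving it
    			}
    		}
    		timer.Stop()

    		requestList := make([]Request, len(batch))
    		chanIds := make([]string, len(batch))
    		for i, op := range batch {
    			requestList[i] = Request{Key: op.Key, Value: op.Value}
    			chanIds[i] = op.channelId
    		}

    		returnValues, err := e.o.Batching(requestList, len(batch))
    		if err != nil {
    			fmt.Printf("ORAM batch error: %v\n", err)
    			continue
    		}

    		// Route each result back to the channel registered by its RPC.
    		e.channelLock.RLock()
    		for i := range batch {
    			if rc, ok := e.channelMap[chanIds[i]]; ok {
    				rc.channel <- KVPair{Key: requestList[i].Key, Value: returnValues[i]}
    			}
    		}
    		e.channelLock.RUnlock()
    	}
    }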