Skip to content

Commit ba2d93a

Browse files
black-dragon74 authored and mergify[bot] committed
connection: Disconnect idle gRPC connections on inactivity
This patch adds support to close lingering gRPC connections that have not been used for a certain amount of time, defaulting to 5 minutes. An interceptor is used to track the last access time in a reliable manner. If an expired connection is requested, it will first be re-created and then returned to the caller. Since closing a gRPC connection is a complete tear-down rather than a simple disconnect, re-creating the connection is necessary. Signed-off-by: Niraj Yadav <niryadav@redhat.com>
1 parent ef1bf8b commit ba2d93a

File tree

3 files changed

+199
-36
lines changed

3 files changed

+199
-36
lines changed

internal/connection/connection.go

Lines changed: 135 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package connection
1919
import (
2020
"context"
2121
"crypto/tls"
22+
"sync"
23+
"sync/atomic"
2224
"time"
2325

2426
"github.com/csi-addons/kubernetes-csi-addons/internal/kubernetes/token"
@@ -29,23 +31,67 @@ import (
2931
"google.golang.org/grpc/credentials/insecure"
3032
)
3133

34+
const (
35+
// The duration after which a gRPC connection is closed due to inactivity
36+
idleTimeout = time.Minute * 5
37+
)
38+
3239
// Connection struct consists of to NodeID, DriverName, Capabilities for the controller
3340
// to pick sidecar connection and grpc Client to connect to the sidecar.
3441
type Connection struct {
42+
sync.Mutex
43+
3544
Client *grpc.ClientConn
3645
Capabilities []*identity.Capability
3746
Namespace string
3847
Name string
3948
NodeID string
4049
DriverName string
4150
Timeout time.Duration
51+
52+
// Holds the internal state of the connection
53+
connected bool
54+
enableAuth bool
55+
endpoint string
56+
podName string
57+
58+
// Used to cancel any existing timers in case of re-connects
59+
cancelIdle context.CancelFunc
60+
// Holds the last access time of the gRPC Connection
61+
// It is tracked and updated by the `accessTimeInterceptor`
62+
lastAccessTime atomic.Int64
4263
}
4364

44-
// NewConnection establishes connection with sidecar, fetches capability and returns Connection object
45-
// filled with required information.
46-
func NewConnection(ctx context.Context, endpoint, nodeID, driverName, namespace, podName string, enableAuth bool) (*Connection, error) {
47-
var opts []grpc.DialOption
48-
if enableAuth {
65+
// accessTimeInterceptor is an unary interceptor which updates the lastAccessTime
66+
// on `Connection` struct to `time.Now()` on each RPC call
67+
func accessTimeInterceptor(conn *Connection) grpc.UnaryClientInterceptor {
68+
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
69+
conn.lastAccessTime.Store(time.Now().UnixNano())
70+
71+
return invoker(ctx, method, req, reply, cc, opts...)
72+
}
73+
}
74+
75+
// Connect creates a new grpc.ClientConn object and sets it as the
76+
// client property on Connection struct. If a connection is already
77+
// connected, it is reused as is. It also spawns a go routine to tear
78+
// down the connection if it has been idling for a specified threshold.
79+
//
80+
// In cases where a new connection is created from the scratch Connect
81+
// also calls fetchCapabilities on the connection object.
82+
func (c *Connection) Connect() error {
83+
c.Lock()
84+
defer c.Unlock()
85+
86+
// Return early if already connected, the connection will be reused
87+
if c.connected {
88+
return nil
89+
}
90+
91+
opts := []grpc.DialOption{
92+
grpc.WithUnaryInterceptor(accessTimeInterceptor(c)),
93+
}
94+
if c.enableAuth {
4995
opts = append(opts, token.WithServiceAccountToken())
5096
tlsConfig := &tls.Config{
5197
// Certs are only used to initiate HTTPS connections; authorization is handled by SA tokens
@@ -56,34 +102,110 @@ func NewConnection(ctx context.Context, endpoint, nodeID, driverName, namespace,
56102
} else {
57103
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
58104
}
59-
cc, err := grpc.NewClient(endpoint, opts...)
105+
cc, err := grpc.NewClient(c.endpoint, opts...)
60106
if err != nil {
61-
return nil, err
107+
return err
108+
}
109+
c.Client = cc
110+
c.connected = true
111+
112+
// Fetch the caps
113+
// Downside is we need to do this on every re-connect
114+
if err := c.fetchCapabilities(context.Background()); err != nil {
115+
if e := c.Client.Close(); e == nil {
116+
c.connected = false
117+
}
118+
119+
return err
120+
}
121+
122+
// Start a goroutine to close the connection after idle timeout
123+
// But first, expire any existing timers
124+
if c.cancelIdle != nil {
125+
c.cancelIdle()
62126
}
127+
idleCtx, cFunc := context.WithCancel(context.Background())
128+
c.cancelIdle = cFunc
129+
go c.startIdleTimer(idleCtx)
130+
131+
return nil
132+
}
133+
134+
// startIdleTimer starts a ticker with an interval of 30 seconds
135+
// At each tick, it checks if the connection has been idle for
136+
// more than `idleTimeout`. If so, the connection is closed.
137+
func (c *Connection) startIdleTimer(ctx context.Context) {
138+
ticker := time.NewTicker(time.Second * 30)
139+
defer ticker.Stop()
140+
141+
for {
142+
select {
143+
case <-ticker.C:
144+
c.Lock()
145+
146+
lastAccess := time.Unix(0, c.lastAccessTime.Load())
147+
isIdle := time.Since(lastAccess) > idleTimeout
148+
if isIdle && c.connected {
149+
// It's okay if there's an error in tearing down the connection
150+
// It will be reused by subsequent requests, see Connect() for details
151+
if err := c.Client.Close(); err == nil {
152+
c.connected = false
153+
}
154+
}
155+
c.Unlock()
156+
157+
case <-ctx.Done():
158+
// Timer was cancelled
159+
return
160+
}
161+
}
162+
}
163+
164+
// NewConnection establishes connection with sidecar, fetches capability and returns Connection object
165+
// filled with required information.
166+
func NewConnection(ctx context.Context, endpoint, nodeID, driverName, namespace, podName string, enableAuth bool) (*Connection, error) {
63167

64168
conn := &Connection{
65-
Client: cc,
66169
Namespace: namespace,
67170
Name: podName,
68171
NodeID: nodeID,
69172
DriverName: driverName,
70173
Timeout: time.Minute,
174+
175+
endpoint: endpoint,
176+
podName: podName,
177+
enableAuth: enableAuth,
71178
}
72179

73-
err = conn.fetchCapabilities(ctx)
74-
if err != nil {
180+
if err := conn.Connect(); err != nil {
75181
return nil, err
76182
}
77183

78184
return conn, nil
79185
}
80186

187+
// Close tears down the gRPC connection and terminates the goroutines
188+
// monitoring the idle timeout by calling cancelIdle()
81189
func (c *Connection) Close() error {
82-
if c.Client == nil {
83-
return nil
190+
c.Lock()
191+
defer c.Unlock()
192+
193+
if c.cancelIdle != nil {
194+
// Cancel the context as well
195+
c.cancelIdle()
196+
c.cancelIdle = nil
84197
}
85198

86-
return c.Client.Close()
199+
if c.Client != nil {
200+
err := c.Client.Close()
201+
if err == nil {
202+
c.connected = false
203+
}
204+
205+
return err
206+
}
207+
208+
return nil
87209
}
88210

89211
// fetchCapabilities fetches the capability of the connected CSI driver.

internal/connection/connection_pool.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import (
2828
"github.com/csi-addons/kubernetes-csi-addons/internal/util"
2929
)
3030

31+
const failedToReconnectFmtStr = "failed to reconnect an inactive connection due to error: %w"
32+
3133
//+kubebuilder:rbac:groups=coordination.k8s.io,resources=leases,verbs=get;list;watch
3234

3335
// ConnectionPool consists of map of Connection objects and
@@ -82,34 +84,45 @@ func (cp *ConnectionPool) Delete(key string) {
8284

8385
// getByDriverName returns map of connections filtered by driverName. This function
8486
// must be called with read lock held.
85-
func (cp *ConnectionPool) getByDriverName(driverName string) map[string]*Connection {
87+
func (cp *ConnectionPool) getByDriverName(driverName string) (map[string]*Connection, error) {
8688
newPool := make(map[string]*Connection)
8789
for k, v := range cp.pool {
8890
if v.DriverName != driverName {
8991
continue
9092
}
93+
if err := v.Connect(); err != nil {
94+
return nil, fmt.Errorf(failedToReconnectFmtStr, err)
95+
}
9196
newPool[k] = v
9297
}
9398

94-
return newPool
99+
return newPool, nil
95100
}
96101

97102
// GetByNodeID returns map of connections, filtered with given driverName and optional nodeID.
98-
func (cp *ConnectionPool) GetByNodeID(driverName, nodeID string) map[string]*Connection {
103+
func (cp *ConnectionPool) GetByNodeID(driverName, nodeID string) (map[string]*Connection, error) {
99104
cp.rwlock.RLock()
100105
defer cp.rwlock.RUnlock()
101106

102-
pool := cp.getByDriverName(driverName)
107+
pool, err := cp.getByDriverName(driverName)
108+
if err != nil {
109+
return nil, err
110+
}
103111
result := make(map[string]*Connection)
104112
for k, v := range pool {
105113
// since nodeID is options,check only if it is not empty
106114
if nodeID != "" && v.NodeID != nodeID {
107115
continue
108116
}
117+
// We may skip this one as the validation is done in `getByDriverName` as well
118+
// But since this is just checking a boolean, leaving it as-is is fine too.
119+
if err = v.Connect(); err != nil {
120+
return nil, fmt.Errorf(failedToReconnectFmtStr, err)
121+
}
109122
result[k] = v
110123
}
111124

112-
return result
125+
return result, nil
113126
}
114127

115128
// getNamespaceByDriverName loops through the connections in the pool and
@@ -166,5 +179,9 @@ func (cp *ConnectionPool) GetLeaderByDriver(ctx context.Context, reconciler clie
166179
return nil, fmt.Errorf("no connection with key %q found for driver %q: %w", key, driverName, err)
167180
}
168181

182+
if err = conn.Connect(); err != nil {
183+
return nil, fmt.Errorf(failedToReconnectFmtStr, err)
184+
}
185+
169186
return conn, nil
170187
}

0 commit comments

Comments
 (0)