@@ -16,6 +16,7 @@ import jsonrpclib.internals.MessageDispatcher
1616import jsonrpclib .internals ._
1717
1818import scala .util .Try
19+ import java .util .regex .Pattern
1920
2021trait FS2Channel [F [_]] extends Channel [F ] {
2122
@@ -52,7 +53,7 @@ object FS2Channel {
5253 ): Stream [F , FS2Channel [F ]] = {
5354 for {
5455 supervisor <- Stream .resource(Supervisor [F ])
55- ref <- Ref [F ].of(State [F ](Map .empty, Map .empty, Map .empty, 0 )).toStream
56+ ref <- Ref [F ].of(State [F ](Map .empty, Map .empty, Map .empty, Vector .empty, 0 )).toStream
5657 queue <- cats.effect.std.Queue .bounded[F , Payload ](bufferSize).toStream
5758 impl = new Impl (queue, ref, supervisor, cancelTemplate)
5859
@@ -73,6 +74,7 @@ object FS2Channel {
7374 runningCalls : Map [CallId , Fiber [F , Throwable , Unit ]],
7475 pendingCalls : Map [CallId , OutputMessage => F [Unit ]],
7576 endpoints : Map [String , Endpoint [F ]],
77+ globEndpoints : Vector [(Pattern , Endpoint [F ])],
7678 counter : Long
7779 ) {
7880 def nextCallId : (State [F ], CallId ) = (this .copy(counter = counter + 1 ), CallId .NumberId (counter))
@@ -82,11 +84,27 @@ object FS2Channel {
8284 val result = pendingCalls.get(callId)
8385 (this .copy(pendingCalls = pendingCalls.removed(callId)), result)
8486 }
85- def mountEndpoint (endpoint : Endpoint [F ]): Either [ConflictingMethodError , State [F ]] =
86- endpoints.get(endpoint.method) match {
87- case None => Right (this .copy(endpoints = endpoints + (endpoint.method -> endpoint)))
88- case Some (_) => Left (ConflictingMethodError (endpoint.method))
87+ def mountEndpoint (endpoint : Endpoint [F ]): Either [ConflictingMethodError , State [F ]] = {
88+ import endpoint .method
89+ if (method.contains(" *" )) {
90+ val parts = method
91+ .split(" \\ *" , - 1 )
92+ .map { // Don't discard trailing empty string, if any.
93+ case " " => " "
94+ case str => Pattern .quote(str)
95+ }
96+ val glob = Pattern .compile(parts.mkString(" .*" ))
97+ Right (this .copy(globEndpoints = globEndpoints :+ (glob -> endpoint)))
98+ } else {
99+ endpoints.get(endpoint.method) match {
100+ case None => Right (this .copy(endpoints = endpoints + (endpoint.method -> endpoint)))
101+ case Some (_) => Left (ConflictingMethodError (endpoint.method))
102+ }
89103 }
104+ }
105+ def getEndpoint (method : String ): Option [Endpoint [F ]] = {
106+ endpoints.get(method).orElse(globEndpoints.find(_._1.matcher(method).matches()).map(_._2))
107+ }
90108 def removeEndpoint (method : String ): State [F ] =
91109 copy(endpoints = endpoints.removed(method))
92110
@@ -135,7 +153,7 @@ object FS2Channel {
135153 }
136154 }
137155 protected def reportError (params : Option [Payload ], error : ProtocolError , method : String ): F [Unit ] = ???
138- protected def getEndpoint (method : String ): F [Option [Endpoint [F ]]] = state.get.map(_.endpoints.get (method))
156+ protected def getEndpoint (method : String ): F [Option [Endpoint [F ]]] = state.get.map(_.getEndpoint (method))
139157 protected def sendMessage (message : Message ): F [Unit ] = queue.offer(Codec .encode(message))
140158
141159 protected def nextCallId (): F [CallId ] = state.modify(_.nextCallId)
0 commit comments