Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/Simplex/Messaging/Agent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,15 @@ import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.AgentStore
import Simplex.Messaging.Agent.Store.Common (DBStore)
import qualified Simplex.Messaging.Agent.Store.DB as DB
import Simplex.Messaging.Agent.Store.Entity
import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations)
import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration)
import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, nonBlockingWriteTBQueue, temporaryClientError, unexpectedResponse)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs)
import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
import qualified Simplex.Messaging.Crypto.ShortLink as SL
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import qualified Simplex.Messaging.Crypto.ShortLink as SL
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..), NtfTokenId, PNMessageData (..), pnMessagesP)
Expand Down Expand Up @@ -217,7 +218,6 @@ import Simplex.Messaging.Protocol
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
import Simplex.Messaging.Agent.Store.Entity
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.Util
Expand Down Expand Up @@ -833,8 +833,9 @@ newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> B
newConn c nm userId enableNtfs cMode userData_ clientData pqInitKeys subMode = do
srv <- getSMPServer c userId
connId <- newConnNoQueues c userId enableNtfs cMode (CR.connPQEncryption pqInitKeys)
(connId,) <$> newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srv
`catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e
(connId,)
<$> newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srv
`catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e

setConnShortLink' :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserLinkData -> Maybe CRClientData -> AM (ConnShortLink c)
setConnShortLink' c nm connId cMode userData clientData =
Expand Down Expand Up @@ -914,8 +915,7 @@ getConnShortLink' c nm userId = \case
decryptData srv linkKey k (sndId, d) = do
r@(cReq, clData) <- liftEither $ SL.decryptLinkData @c linkKey k d
let (srv', sndId') = qAddress (connReqQueue cReq)
unless (srv `sameSrvHost` srv' && sndId == sndId') $
throwE $ AGENT $ A_LINK "different address"
unless (srv `sameSrvHost` srv' && sndId == sndId') $ throwE $ AGENT $ A_LINK "different address"
pure $ if srv' == srv then r else (updateConnReqServer srv cReq, clData)
sameSrvHost ProtocolServer {host = h :| _} ProtocolServer {host = hs} = h `elem` hs
updateConnReqServer :: SMPServer -> ConnectionRequestUri c -> ConnectionRequestUri c
Expand Down Expand Up @@ -1004,7 +1004,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey
connReqWithShortLink :: SMPQueueUri -> ConnectionRequestUri c -> SMPQueueUri -> Maybe ShortLinkCreds -> AM (CreatedConnLink c)
connReqWithShortLink qUri cReq qUri' shortLink = case shortLink of
Just ShortLinkCreds {shortLinkId, shortLinkKey}
| qUri == qUri' -> pure $ case cReq of
| qUri == qUri' -> pure $ case cReq of
CRContactUri _ -> CCLink cReq $ Just $ CSLContact SLSServer CCTContact srv shortLinkKey
CRInvitationUri crData (CR.E2ERatchetParamsUri vr k1 k2 _) ->
let cReq' = case pqInitKeys of
Expand Down Expand Up @@ -1682,7 +1682,7 @@ enqueueMessageB c reqs = do
storeSentMsg db cfg aMessageIds = \case
Left e -> pure (aMessageIds, Left e)
Right req@(csqs_, pqEnc_, msgFlags, mbr) -> case mbr of
VRValue i_ aMessage -> case i_ >>= (`IM.lookup` aMessageIds) of
VRValue i_ aMessage -> case i_ >>= (`IM.lookup` aMessageIds) of
Just _ -> pure (aMessageIds, Left $ INTERNAL "enqueueMessageB: storeSentMsg duplicate saved message body")
Nothing -> do
(mbId_, r) <- case csqs_ of
Expand Down Expand Up @@ -1724,7 +1724,6 @@ enqueueMessageB c reqs = do
handleInternal :: E.SomeException -> IO (Either AgentErrorType b)
handleInternal = pure . Left . INTERNAL . show


encodeAgentMsgStr :: AMessage -> InternalSndId -> PrevSndMsgHash -> ByteString
encodeAgentMsgStr aMessage internalSndId prevMsgHash = do
let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash
Expand Down Expand Up @@ -2536,10 +2535,13 @@ getNextSMPServer c userId = getNextServer c userId storageSrvs
{-# INLINE getNextSMPServer #-}

subscriber :: AgentClient -> AM' ()
subscriber c@AgentClient {msgQ} = forever $ do
subscriber c@AgentClient {msgQ, subQ} = run $ forever $ do
t <- atomically $ readTBQueue msgQ
agentOperationBracket c AORcvNetwork waitUntilActive $
processSMPTransmissions c t
where
run a = a `catchOwn` \e -> notify $ CRITICAL True $ "Agent subscriber stopped: " <> show e
notify err = atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR err)

cleanupManager :: AgentClient -> AM' ()
cleanupManager c@AgentClient {subQ} = do
Expand Down Expand Up @@ -2848,7 +2850,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
ackDel :: InternalId -> AM ACKd
ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd
handleNotifyAck :: AM ACKd -> AM ACKd
handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack
handleNotifyAck m = m `catchAllOwnErrors` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False))
>>= notifyEnd
Expand Down
6 changes: 4 additions & 2 deletions src/Simplex/Messaging/Agent/Protocol.hs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ module Simplex.Messaging.Agent.Protocol
where

import Control.Applicative (optional, (<|>))
import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException)
import Control.Exception (BlockedIndefinitelyOnMVar (..), BlockedIndefinitelyOnSTM (..), fromException)
import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?))
import qualified Data.Aeson as J'
import qualified Data.Aeson.Encoding as JE
Expand Down Expand Up @@ -1870,7 +1870,9 @@ data AgentErrorType
instance AnyError AgentErrorType where
fromSomeException e = case fromException e of
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
_ -> INTERNAL $ show e
_ -> case fromException e of
Just BlockedIndefinitelyOnMVar -> CRITICAL True "Thread blocked indefinitely on MVar"
_ -> INTERNAL $ show e
{-# INLINE fromSomeException #-}

-- | SMP agent protocol command or response error.
Expand Down
14 changes: 8 additions & 6 deletions src/Simplex/Messaging/Compression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ compress1 bs
| B.length bs <= maxLengthPassthrough = Passthrough bs
| otherwise = Compressed . Large $ Z1.compress compressionLevel bs

decompress1 :: Compressed -> Either String ByteString
decompress1 = \case
decompress1 :: Int -> Compressed -> Either String ByteString
decompress1 limit = \case
Passthrough bs -> Right bs
Compressed (Large bs) -> case Z1.decompress bs of
Z1.Error e -> Left e
Z1.Skip -> Right mempty
Z1.Decompress bs' -> Right bs'
Compressed (Large bs) -> case Z1.decompressedSize bs of
Just sz | sz <= limit -> case Z1.decompress bs of
Z1.Error e -> Left e
Z1.Skip -> Right mempty
Z1.Decompress bs' -> Right bs'
_ -> Left $ "compressed size not specified or exceeds " <> show limit
87 changes: 69 additions & 18 deletions src/Simplex/Messaging/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module Simplex.Messaging.Util where

import Control.Exception (AllocationLimitExceeded (..), AsyncException (..))
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Except
Expand All @@ -21,9 +22,9 @@ import Data.Int (Int64)
import Data.List (groupBy, sortOn)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (listToMaybe)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, encodeUtf8)
Expand Down Expand Up @@ -93,7 +94,7 @@ anyM :: Monad m => [m Bool] -> m Bool
anyM = foldM (\r a -> if r then pure r else (r ||) <$!> a) False
{-# INLINE anyM #-}

infixl 1 $>>, $>>=
infixl 1 $>>, $>>=

($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
f $>>= g = f >>= fmap join . mapM g
Expand All @@ -115,15 +116,19 @@ forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) ->
forME = flip mapME
{-# INLINE forME #-}


-- | Monadic version of mapAccumL
-- Copied from ghc-9.6.3 package: https://hackage.haskell.org/package/ghc-9.12.1/docs/GHC-Utils-Monad.html#v:mapAccumLM
-- for backward compatibility with 8.10.7.
mapAccumLM :: (Monad m, Traversable t)
=> (acc -> x -> m (acc, y)) -- ^ combining function
-> acc -- ^ initial state
-> t x -- ^ inputs
-> m (acc, t y) -- ^ final state, outputs
mapAccumLM ::
(Monad m, Traversable t) =>
-- | combining function
(acc -> x -> m (acc, y)) ->
-- | initial state
acc ->
-- | inputs
t x ->
-- | final state, outputs
m (acc, t y)
{-# INLINE [1] mapAccumLM #-}
-- INLINE pragma. mapAccumLM is called in inner loops. Like 'map',
-- we inline it so that we can take advantage of knowing 'f'.
Expand All @@ -132,26 +137,31 @@ mapAccumLM :: (Monad m, Traversable t)
mapAccumLM f s = fmap swap . flip runStateT s . traverse f'
where
f' = StateT . (fmap . fmap) swap . flip f

{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-}
{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-}

mapAccumLM_List
:: Monad m
=> (acc -> x -> m (acc, y))
-> acc -> [x] -> m (acc, [y])
mapAccumLM_List ::
Monad m =>
(acc -> x -> m (acc, y)) ->
acc ->
[x] ->
m (acc, [y])
{-# INLINE mapAccumLM_List #-}
mapAccumLM_List f = go
where
go s (x : xs) = do
(s1, x') <- f s x
(s1, x') <- f s x
(s2, xs') <- go s1 xs
return (s2, x' : xs')
return (s2, x' : xs')
go s [] = return (s, [])

mapAccumLM_NonEmpty
:: Monad m
=> (acc -> x -> m (acc, y))
-> acc -> NonEmpty x -> m (acc, NonEmpty y)
mapAccumLM_NonEmpty ::
Monad m =>
(acc -> x -> m (acc, y)) ->
acc ->
NonEmpty x ->
m (acc, NonEmpty y)
{-# INLINE mapAccumLM_NonEmpty #-}
mapAccumLM_NonEmpty f s (x :| xs) =
[(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs]
Expand Down Expand Up @@ -197,6 +207,47 @@ allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b ->
allFinally action final = tryAllErrors action >>= \r -> final >> except r
{-# INLINE allFinally #-}

isOwnException :: E.SomeException -> Bool
isOwnException e = case E.fromException e of
Just StackOverflow -> True
Just HeapOverflow -> True
_ -> case E.fromException e of
Just AllocationLimitExceeded -> True
_ -> False
{-# INLINE isOwnException #-}

isAsyncCancellation :: E.SomeException -> Bool
isAsyncCancellation e = case E.fromException e of
Just (_ :: SomeAsyncException) -> not $ isOwnException e
Nothing -> False
{-# INLINE isAsyncCancellation #-}

catchOwn' :: IO a -> (E.SomeException -> IO a) -> IO a
catchOwn' action handleInternal = action `E.catch` \e -> if isAsyncCancellation e then E.throwIO e else handleInternal e
{-# INLINE catchOwn' #-}

catchOwn :: MonadUnliftIO m => m a -> (E.SomeException -> m a) -> m a
catchOwn action handleInternal =
withRunInIO $ \run ->
run action `E.catch` \e -> if isAsyncCancellation e then E.throwIO e else run (handleInternal e)
{-# INLINE catchOwn #-}

tryAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
tryAllOwnErrors action = ExceptT $ Right <$> runExceptT action `catchOwn` (pure . Left . fromSomeException)
{-# INLINE tryAllOwnErrors #-}

tryAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
tryAllOwnErrors' action = runExceptT action `catchOwn` (pure . Left . fromSomeException)
{-# INLINE tryAllOwnErrors' #-}

catchAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
catchAllOwnErrors action handler = tryAllOwnErrors action >>= either handler pure
{-# INLINE catchAllOwnErrors #-}

catchAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
catchAllOwnErrors' action handler = tryAllOwnErrors' action >>= either handler pure
{-# INLINE catchAllOwnErrors' #-}

eitherToMaybe :: Either a b -> Maybe b
eitherToMaybe = either (const Nothing) Just
{-# INLINE eitherToMaybe #-}
Expand Down
2 changes: 1 addition & 1 deletion tests/AgentTests/MigrationTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do
poolSize = 1,
createSchema = True
}
createDBStore dbOpts migrations confirmMigrations
createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing)

cleanup :: Word32 -> IO ()
cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix)
Expand Down
34 changes: 33 additions & 1 deletion tests/CoreTests/UtilTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module CoreTests.UtilTests where

import AgentTests.FunctionalAPITests ()
import Control.Exception (Exception, SomeException, throwIO)
import Control.Exception (AllocationLimitExceeded (..), AsyncException (..), Exception, SomeException, throwIO)
import Control.Monad.Except
import Control.Monad.IO.Class
import Data.IORef
Expand Down Expand Up @@ -71,11 +71,43 @@ utilTests = do
runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)")
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors"
describe "tryAllOwnErrors" $ do
it "should return ExceptT error as Left" $
runExceptT (tryAllOwnErrors throwTestError) `shouldReturn` Right (Left (TestError "error"))
it "should return SomeException as Left" $
runExceptT (tryAllOwnErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
it "should catch StackOverflow" $
runExceptT (tryAllOwnErrors $ throwAsync StackOverflow) `shouldReturn` Right (Left (TestException "stack overflow"))
it "should catch HeapOverflow" $
runExceptT (tryAllOwnErrors $ throwAsync HeapOverflow) `shouldReturn` Right (Left (TestException "heap overflow"))
it "should catch AllocationLimitExceeded" $
runExceptT (tryAllOwnErrors $ throwAsync AllocationLimitExceeded) `shouldReturn` Right (Left (TestException "allocation limit exceeded"))
it "should rethrow ThreadKilled" $
runExceptT (tryAllOwnErrors $ throwAsync ThreadKilled) `shouldThrow` (\e -> e == ThreadKilled)
it "should return no errors as Right" $
runExceptT (tryAllOwnErrors noErrors) `shouldReturn` Right (Right "no errors")
describe "catchAllOwnErrors" $ do
it "should catch ExceptT error" $
runExceptT (throwTestError `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\""
it "should catch SomeException" $
runExceptT (throwTestException `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
it "should catch StackOverflow" $
runExceptT (throwAsync StackOverflow `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"stack overflow\""
it "should catch HeapOverflow" $
runExceptT (throwAsync HeapOverflow `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"heap overflow\""
it "should catch AllocationLimitExceeded" $
runExceptT (throwAsync AllocationLimitExceeded `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"allocation limit exceeded\""
it "should rethrow ThreadKilled" $
runExceptT (throwAsync ThreadKilled `catchAllOwnErrors` handleCatch) `shouldThrow` (\e -> e == ThreadKilled)
it "should not throw if there are no errors" $
runExceptT (noErrors `catchAllOwnErrors` throwError) `shouldReturn` Right "no errors"
where
throwTestError :: ExceptT TestError IO String
throwTestError = throwError $ TestError "error"
throwTestException :: ExceptT TestError IO String
throwTestException = liftIO $ throwIO $ userError "error"
throwAsync :: Exception e => e -> ExceptT TestError IO String
throwAsync = liftIO . throwIO
noErrors :: ExceptT TestError IO String
noErrors = pure "no errors"
handleCatch :: TestError -> ExceptT TestError IO String
Expand Down