diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 982c3099a..3174e05ec 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -180,14 +180,15 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.AgentStore import Simplex.Messaging.Agent.Store.Common (DBStore) import qualified Simplex.Messaging.Agent.Store.DB as DB +import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations) import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, nonBlockingWriteTBQueue, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) -import qualified Simplex.Messaging.Crypto.ShortLink as SL import qualified Simplex.Messaging.Crypto.Ratchet as CR +import qualified Simplex.Messaging.Crypto.ShortLink as SL import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..), NtfTokenId, PNMessageData (..), pnMessagesP) @@ -217,7 +218,6 @@ import Simplex.Messaging.Protocol ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) -import Simplex.Messaging.Agent.Store.Entity import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.Util @@ -833,8 +833,9 @@ newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> B newConn c nm userId enableNtfs cMode userData_ clientData pqInitKeys subMode = do srv <- getSMPServer c userId connId <- newConnNoQueues c userId enableNtfs cMode (CR.connPQEncryption pqInitKeys) - (connId,) <$> newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srv - `catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e + (connId,) + <$> newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srv + `catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e setConnShortLink' :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserLinkData -> Maybe CRClientData -> AM (ConnShortLink c) setConnShortLink' c nm connId cMode userData clientData = @@ -914,8 +915,7 @@ getConnShortLink' c nm userId = \case decryptData srv linkKey k (sndId, d) = do r@(cReq, clData) <- liftEither $ SL.decryptLinkData @c linkKey k d let (srv', sndId') = qAddress (connReqQueue cReq) - unless (srv `sameSrvHost` srv' && sndId == sndId') $ - throwE $ AGENT $ A_LINK "different address" + unless (srv `sameSrvHost` srv' && sndId == sndId') $ throwE $ AGENT $ A_LINK "different address" pure $ if srv' == srv then r else (updateConnReqServer srv cReq, clData) sameSrvHost ProtocolServer {host = h :| _} ProtocolServer {host = hs} = h `elem` hs updateConnReqServer :: SMPServer -> ConnectionRequestUri c -> ConnectionRequestUri c @@ -1004,7 +1004,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey connReqWithShortLink :: SMPQueueUri -> ConnectionRequestUri c -> SMPQueueUri -> Maybe ShortLinkCreds -> AM (CreatedConnLink c) connReqWithShortLink qUri cReq qUri' shortLink = case shortLink of Just ShortLinkCreds {shortLinkId, shortLinkKey} - | qUri == qUri' -> pure $ case cReq of + | qUri == qUri' -> pure $ case cReq of CRContactUri _ -> CCLink cReq $ Just $ CSLContact SLSServer CCTContact srv shortLinkKey CRInvitationUri crData (CR.E2ERatchetParamsUri vr k1 k2 _) -> let cReq' = case pqInitKeys of @@ -1682,7 +1682,7 @@ enqueueMessageB c reqs = do storeSentMsg db cfg aMessageIds = \case Left e -> pure (aMessageIds, Left e) Right req@(csqs_, pqEnc_, msgFlags, mbr) -> case mbr of - VRValue i_ aMessage -> case i_ >>= (`IM.lookup` aMessageIds) of + VRValue i_ aMessage -> case i_ >>= (`IM.lookup` aMessageIds) of Just _ -> pure (aMessageIds, Left $ INTERNAL "enqueueMessageB: storeSentMsg duplicate saved message body") Nothing -> do (mbId_, r) <- case csqs_ of @@ -1724,7 +1724,6 @@ enqueueMessageB c reqs = do handleInternal :: E.SomeException -> IO (Either AgentErrorType b) handleInternal = pure . Left . INTERNAL . show - encodeAgentMsgStr :: AMessage -> InternalSndId -> PrevSndMsgHash -> ByteString encodeAgentMsgStr aMessage internalSndId prevMsgHash = do let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash @@ -2536,10 +2535,13 @@ getNextSMPServer c userId = getNextServer c userId storageSrvs {-# INLINE getNextSMPServer #-} subscriber :: AgentClient -> AM' () -subscriber c@AgentClient {msgQ} = forever $ do +subscriber c@AgentClient {msgQ, subQ} = run $ forever $ do t <- atomically $ readTBQueue msgQ agentOperationBracket c AORcvNetwork waitUntilActive $ processSMPTransmissions c t + where + run a = a `catchOwn` \e -> notify $ CRITICAL True $ "Agent subscriber stopped: " <> show e + notify err = atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR err) cleanupManager :: AgentClient -> AM' () cleanupManager c@AgentClient {subQ} = do @@ -2848,7 +2850,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId ackDel :: InternalId -> AM ACKd ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd handleNotifyAck :: AM ACKd -> AM ACKd - handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack + handleNotifyAck m = m `catchAllOwnErrors` \e -> notify (ERR e) >> ack SMP.END -> atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False)) >>= notifyEnd diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index d4d302df7..df04d7c12 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -173,7 +173,7 @@ module Simplex.Messaging.Agent.Protocol where import Control.Applicative (optional, (<|>)) -import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException) +import Control.Exception (BlockedIndefinitelyOnMVar (..), BlockedIndefinitelyOnSTM (..), fromException) import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?)) import qualified Data.Aeson as J' import qualified Data.Aeson.Encoding as JE @@ -1870,7 +1870,9 @@ data AgentErrorType instance AnyError AgentErrorType where fromSomeException e = case fromException e of Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction" - _ -> INTERNAL $ show e + _ -> case fromException e of + Just BlockedIndefinitelyOnMVar -> CRITICAL True "Thread blocked indefinitely on MVar" + _ -> INTERNAL $ show e {-# INLINE fromSomeException #-} -- | SMP agent protocol command or response error. diff --git a/src/Simplex/Messaging/Compression.hs b/src/Simplex/Messaging/Compression.hs index 19d91a300..18efe1a0d 100644 --- a/src/Simplex/Messaging/Compression.hs +++ b/src/Simplex/Messaging/Compression.hs @@ -36,10 +36,12 @@ compress1 bs | B.length bs <= maxLengthPassthrough = Passthrough bs | otherwise = Compressed . Large $ Z1.compress compressionLevel bs -decompress1 :: Compressed -> Either String ByteString -decompress1 = \case +decompress1 :: Int -> Compressed -> Either String ByteString +decompress1 limit = \case Passthrough bs -> Right bs - Compressed (Large bs) -> case Z1.decompress bs of - Z1.Error e -> Left e - Z1.Skip -> Right mempty - Z1.Decompress bs' -> Right bs' + Compressed (Large bs) -> case Z1.decompressedSize bs of + Just sz | sz <= limit -> case Z1.decompress bs of + Z1.Error e -> Left e + Z1.Skip -> Right mempty + Z1.Decompress bs' -> Right bs' + _ -> Left $ "compressed size not specified or exceeds " <> show limit diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index f93119b3c..780fc0b1c 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -4,6 +4,7 @@ module Simplex.Messaging.Util where +import Control.Exception (AllocationLimitExceeded (..), AsyncException (..)) import qualified Control.Exception as E import Control.Monad import Control.Monad.Except @@ -21,9 +22,9 @@ import Data.Int (Int64) import Data.List (groupBy, sortOn) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Maybe (listToMaybe) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Maybe (listToMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8With, encodeUtf8) @@ -93,7 +94,7 @@ anyM :: Monad m => [m Bool] -> m Bool anyM = foldM (\r a -> if r then pure r else (r ||) <$!> a) False {-# INLINE anyM #-} -infixl 1 $>>, $>>= +infixl 1 $>>, $>>= ($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b) f $>>= g = f >>= fmap join . mapM g @@ -115,15 +116,19 @@ forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> forME = flip mapME {-# INLINE forME #-} - -- | Monadic version of mapAccumL -- Copied from ghc-9.6.3 package: https://hackage.haskell.org/package/ghc-9.12.1/docs/GHC-Utils-Monad.html#v:mapAccumLM -- for backward compatibility with 8.10.7. -mapAccumLM :: (Monad m, Traversable t) - => (acc -> x -> m (acc, y)) -- ^ combining function - -> acc -- ^ initial state - -> t x -- ^ inputs - -> m (acc, t y) -- ^ final state, outputs +mapAccumLM :: + (Monad m, Traversable t) => + -- | combining function + (acc -> x -> m (acc, y)) -> + -- | initial state + acc -> + -- | inputs + t x -> + -- | final state, outputs + m (acc, t y) {-# INLINE [1] mapAccumLM #-} -- INLINE pragma. mapAccumLM is called in inner loops. Like 'map', -- we inline it so that we can take advantage of knowing 'f'. @@ -132,26 +137,31 @@ mapAccumLM :: (Monad m, Traversable t) mapAccumLM f s = fmap swap . flip runStateT s . traverse f' where f' = StateT . (fmap . fmap) swap . flip f + {-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-} {-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-} -mapAccumLM_List - :: Monad m - => (acc -> x -> m (acc, y)) - -> acc -> [x] -> m (acc, [y]) +mapAccumLM_List :: + Monad m => + (acc -> x -> m (acc, y)) -> + acc -> + [x] -> + m (acc, [y]) {-# INLINE mapAccumLM_List #-} mapAccumLM_List f = go where go s (x : xs) = do - (s1, x') <- f s x + (s1, x') <- f s x (s2, xs') <- go s1 xs - return (s2, x' : xs') + return (s2, x' : xs') go s [] = return (s, []) -mapAccumLM_NonEmpty - :: Monad m - => (acc -> x -> m (acc, y)) - -> acc -> NonEmpty x -> m (acc, NonEmpty y) +mapAccumLM_NonEmpty :: + Monad m => + (acc -> x -> m (acc, y)) -> + acc -> + NonEmpty x -> + m (acc, NonEmpty y) {-# INLINE mapAccumLM_NonEmpty #-} mapAccumLM_NonEmpty f s (x :| xs) = [(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs] @@ -197,6 +207,47 @@ allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> allFinally action final = tryAllErrors action >>= \r -> final >> except r {-# INLINE allFinally #-} +isOwnException :: E.SomeException -> Bool +isOwnException e = case E.fromException e of + Just StackOverflow -> True + Just HeapOverflow -> True + _ -> case E.fromException e of + Just AllocationLimitExceeded -> True + _ -> False +{-# INLINE isOwnException #-} + +isAsyncCancellation :: E.SomeException -> Bool +isAsyncCancellation e = case E.fromException e of + Just (_ :: SomeAsyncException) -> not $ isOwnException e + Nothing -> False +{-# INLINE isAsyncCancellation #-} + +catchOwn' :: IO a -> (E.SomeException -> IO a) -> IO a +catchOwn' action handleInternal = action `E.catch` \e -> if isAsyncCancellation e then E.throwIO e else handleInternal e +{-# INLINE catchOwn' #-} + +catchOwn :: MonadUnliftIO m => m a -> (E.SomeException -> m a) -> m a +catchOwn action handleInternal = + withRunInIO $ \run -> + run action `E.catch` \e -> if isAsyncCancellation e then E.throwIO e else run (handleInternal e) +{-# INLINE catchOwn #-} + +tryAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a) +tryAllOwnErrors action = ExceptT $ Right <$> runExceptT action `catchOwn` (pure . Left . fromSomeException) +{-# INLINE tryAllOwnErrors #-} + +tryAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a) +tryAllOwnErrors' action = runExceptT action `catchOwn` (pure . Left . fromSomeException) +{-# INLINE tryAllOwnErrors' #-} + +catchAllOwnErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a +catchAllOwnErrors action handler = tryAllOwnErrors action >>= either handler pure +{-# INLINE catchAllOwnErrors #-} + +catchAllOwnErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a +catchAllOwnErrors' action handler = tryAllOwnErrors' action >>= either handler pure +{-# INLINE catchAllOwnErrors' #-} + eitherToMaybe :: Either a b -> Maybe b eitherToMaybe = either (const Nothing) Just {-# INLINE eitherToMaybe #-} diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs index e4de45c7a..8245cfd51 100644 --- a/tests/AgentTests/MigrationTests.hs +++ b/tests/AgentTests/MigrationTests.hs @@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do poolSize = 1, createSchema = True } - createDBStore dbOpts migrations confirmMigrations + createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing) cleanup :: Word32 -> IO () cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix) diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 946902358..580f4e9b0 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -3,7 +3,7 @@ module CoreTests.UtilTests where import AgentTests.FunctionalAPITests () -import Control.Exception (Exception, SomeException, throwIO) +import Control.Exception (AllocationLimitExceeded (..), AsyncException (..), Exception, SomeException, throwIO) import Control.Monad.Except import Control.Monad.IO.Class import Data.IORef @@ -71,11 +71,43 @@ utilTests = do runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)") it "and should not throw if there are no exceptions" $ withFinal $ \final -> runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors" + describe "tryAllOwnErrors" $ do + it "should return ExceptT error as Left" $ + runExceptT (tryAllOwnErrors throwTestError) `shouldReturn` Right (Left (TestError "error")) + it "should return SomeException as Left" $ + runExceptT (tryAllOwnErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) + it "should catch StackOverflow" $ + runExceptT (tryAllOwnErrors $ throwAsync StackOverflow) `shouldReturn` Right (Left (TestException "stack overflow")) + it "should catch HeapOverflow" $ + runExceptT (tryAllOwnErrors $ throwAsync HeapOverflow) `shouldReturn` Right (Left (TestException "heap overflow")) + it "should catch AllocationLimitExceeded" $ + runExceptT (tryAllOwnErrors $ throwAsync AllocationLimitExceeded) `shouldReturn` Right (Left (TestException "allocation limit exceeded")) + it "should rethrow ThreadKilled" $ + runExceptT (tryAllOwnErrors $ throwAsync ThreadKilled) `shouldThrow` (\e -> e == ThreadKilled) + it "should return no errors as Right" $ + runExceptT (tryAllOwnErrors noErrors) `shouldReturn` Right (Right "no errors") + describe "catchAllOwnErrors" $ do + it "should catch ExceptT error" $ + runExceptT (throwTestError `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\"" + it "should catch SomeException" $ + runExceptT (throwTestException `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" + it "should catch StackOverflow" $ + runExceptT (throwAsync StackOverflow `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"stack overflow\"" + it "should catch HeapOverflow" $ + runExceptT (throwAsync HeapOverflow `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"heap overflow\"" + it "should catch AllocationLimitExceeded" $ + runExceptT (throwAsync AllocationLimitExceeded `catchAllOwnErrors` handleCatch) `shouldReturn` Right "caught TestException \"allocation limit exceeded\"" + it "should rethrow ThreadKilled" $ + runExceptT (throwAsync ThreadKilled `catchAllOwnErrors` handleCatch) `shouldThrow` (\e -> e == ThreadKilled) + it "should not throw if there are no errors" $ + runExceptT (noErrors `catchAllOwnErrors` throwError) `shouldReturn` Right "no errors" where throwTestError :: ExceptT TestError IO String throwTestError = throwError $ TestError "error" throwTestException :: ExceptT TestError IO String throwTestException = liftIO $ throwIO $ userError "error" + throwAsync :: Exception e => e -> ExceptT TestError IO String + throwAsync = liftIO . throwIO noErrors :: ExceptT TestError IO String noErrors = pure "no errors" handleCatch :: TestError -> ExceptT TestError IO String