From d4facde6d478fed59e68ea924ccc710b42db05d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20P=2E=20Ren=C3=A9=20de=20Cotret?= Date: Mon, 3 Mar 2025 20:52:57 -0500 Subject: [PATCH] Fixed an issue where `lead1_`, `lag1_`, `lead_`, and `lag_` did not have the appropriate type --- beam-core/ChangeLog.md | 4 + beam-core/Database/Beam/Query/Extensions.hs | 4 +- beam-postgres/beam-postgres.cabal | 9 +- .../Database/Beam/Postgres/Test/Windowing.hs | 221 ++++++++++++++++++ beam-postgres/test/Main.hs | 14 +- 5 files changed, 240 insertions(+), 12 deletions(-) create mode 100644 beam-postgres/test/Database/Beam/Postgres/Test/Windowing.hs diff --git a/beam-core/ChangeLog.md b/beam-core/ChangeLog.md index a9ef442da..aaf3e8d34 100644 --- a/beam-core/ChangeLog.md +++ b/beam-core/ChangeLog.md @@ -5,6 +5,10 @@ * Added a `Generic` instance to `SqlNull`, `SqlBitString`, and `SqlSerial` (#736). * Added a note to `default_` to specify that it has more restrictions than its type may indicate (#744). +## Bug fixes + +* Fixed an issue where `lead1_`, `lag1_`, `lead_`, and `lag_` did not have the appropriate type, leading to runtime exceptions (#745). + ## Updated dependencies * Updated the upper bound to include `containers-0.8`. diff --git a/beam-core/Database/Beam/Query/Extensions.hs b/beam-core/Database/Beam/Query/Extensions.hs index baee82c17..c35bab8cc 100644 --- a/beam-core/Database/Beam/Query/Extensions.hs +++ b/beam-core/Database/Beam/Query/Extensions.hs @@ -47,13 +47,13 @@ ntile_ (QExpr a) = QExpr (ntileE <$> a) lead1_, lag1_ :: (BeamSqlBackend be, BeamSqlT615Backend be) - => QExpr be s a -> QAgg be s a + => QExpr be s a -> QAgg be s (Maybe a) lead1_ (QExpr a) = QExpr (leadE <$> a <*> pure Nothing <*> pure Nothing) lag1_ (QExpr a) = QExpr (lagE <$> a <*> pure Nothing <*> pure Nothing) lead_, lag_ :: (BeamSqlBackend be, BeamSqlT615Backend be, Integral n) - => QExpr be s a -> QExpr be s n -> QAgg be s a + => QExpr be s a -> QExpr be s n -> QAgg be s (Maybe a) lead_ (QExpr a) (QExpr n) = QExpr (leadE <$> a <*> (Just <$> n) <*> pure Nothing) lag_ (QExpr a) (QExpr n) = QExpr (lagE <$> a <*> (Just <$> n) <*> pure Nothing) diff --git a/beam-postgres/beam-postgres.cabal b/beam-postgres/beam-postgres.cabal index 54ef05a8e..1f1a57509 100644 --- a/beam-postgres/beam-postgres.cabal +++ b/beam-postgres/beam-postgres.cabal @@ -77,11 +77,12 @@ test-suite beam-postgres-tests type: exitcode-stdio-1.0 hs-source-dirs: test main-is: Main.hs - other-modules: Database.Beam.Postgres.Test, - Database.Beam.Postgres.Test.Marshal, - Database.Beam.Postgres.Test.Select, - Database.Beam.Postgres.Test.DataTypes, + other-modules: Database.Beam.Postgres.Test + Database.Beam.Postgres.Test.Marshal + Database.Beam.Postgres.Test.Select + Database.Beam.Postgres.Test.DataTypes Database.Beam.Postgres.Test.Migrate + Database.Beam.Postgres.Test.Windowing build-depends: aeson, base, diff --git a/beam-postgres/test/Database/Beam/Postgres/Test/Windowing.hs b/beam-postgres/test/Database/Beam/Postgres/Test/Windowing.hs new file mode 100644 index 000000000..a41fda556 --- /dev/null +++ b/beam-postgres/test/Database/Beam/Postgres/Test/Windowing.hs @@ -0,0 +1,221 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE StandaloneDeriving #-} + +module Database.Beam.Postgres.Test.Windowing (tests) where + +import Database.Beam +import Database.Beam.Backend.SQL.BeamExtensions +import Database.Beam.Migrate +import Database.Beam.Migrate.Simple (autoMigrate) +import Database.Beam.Postgres +import Database.Beam.Postgres.Migrate (migrationBackend) +import Database.Beam.Postgres.Test + +import Control.Exception (SomeException (..), handle) + +import Data.ByteString (ByteString) +import Data.Int +import Data.Text (Text) + +import Control.Monad (void) +import Test.Tasty +import Test.Tasty.HUnit + +tests :: IO ByteString -> TestTree +tests postgresConn = + testGroup + "Windowing unit tests" + [ testLead1 postgresConn + , testLag1 postgresConn + , testLead postgresConn + , testLag postgresConn + , testLeadWithDefault postgresConn + , testLagWithDefault postgresConn + ] + +testLead1 :: IO ByteString -> TestTree +testLead1 = testCase "lead1_" . windowingQueryTest query expectation + where + query = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, lead1_ name `over_` w) + ) + (all_ $ persons db) + expectation = [("Alice", Just "Bob"), ("Bob", Just "Claire"), ("Claire", Nothing)] + +testLag1 :: IO ByteString -> TestTree +testLag1 = testCase "lag1_" . windowingQueryTest query expectation + where + query = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, lag1_ name `over_` w) + ) + (all_ $ persons db) + expectation = [("Alice", Nothing), ("Bob", Just "Alice"), ("Claire", Just "Bob")] + +testLead :: IO ByteString -> TestTree +testLead getConnStr = + testGroup + "lead_" + [ testCase "n=1" $ windowingQueryTest (query 1) [("Alice", Just "Bob"), ("Bob", Just "Claire"), ("Claire", Nothing)] getConnStr + , testCase "n=2" $ windowingQueryTest (query 2) [("Alice", Just "Claire"), ("Bob", Nothing), ("Claire", Nothing)] getConnStr + ] + where + query n = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, lead_ name (val_ (n :: Int32)) `over_` w) + ) + (all_ $ persons db) + expectation1 = [] + +testLag :: IO ByteString -> TestTree +testLag getConnStr = + testGroup + "lag_" + [ testCase "n=1" $ windowingQueryTest (query 1) [("Alice", Nothing), ("Bob", Just "Alice"), ("Claire", Just "Bob")] getConnStr + , testCase "n=2" $ windowingQueryTest (query 2) [("Alice", Nothing), ("Bob", Nothing), ("Claire", Just "Alice")] getConnStr + ] + where + query n = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, lag_ name (val_ (n :: Int32)) `over_` w) + ) + (all_ $ persons db) + expectation = [] + + +testLeadWithDefault :: IO ByteString -> TestTree +testLeadWithDefault getConnStr = + testGroup + "leadWithDefault_" + [ testCase "n=1" $ windowingQueryTest (query 1 "default") [("Alice", "Bob"), ("Bob", "Claire"), ("Claire", "default")] getConnStr + , testCase "n=2" $ windowingQueryTest (query 2 "default") [("Alice", "Claire"), ("Bob", "default"), ("Claire", "default")] getConnStr + ] + where + query n def = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, leadWithDefault_ name (val_ (n :: Int32)) (val_ def) `over_` w) + ) + (all_ $ persons db) + expectation1 = [] + + +testLagWithDefault :: IO ByteString -> TestTree +testLagWithDefault getConnStr = + testGroup + "lagWithDefault_" + [ testCase "n=1" $ windowingQueryTest (query 1 "default") [("Alice", "default"), ("Bob", "Alice"), ("Claire", "Bob")] getConnStr + , testCase "n=2" $ windowingQueryTest (query 2 "default") [("Alice", "default"), ("Bob", "default"), ("Claire", "Alice")] getConnStr + ] + where + query n def = + withWindow_ + ( \Person{name} -> + frame_ + noPartition_ + (orderPartitionBy_ (asc_ name)) + noBounds_ + ) + ( \Person{name} w -> + (name, lagWithDefault_ name (val_ (n :: Int32)) (val_ def) `over_` w) + ) + (all_ $ persons db) + expectation = [] + + + +data PersonT f = Person + { name :: C f Text + } + deriving (Generic) + +type Person = PersonT Identity + +type PersonExpr s = PersonT (QExpr Postgres s) + +deriving instance Show Person +deriving instance Eq Person + +instance Beamable PersonT + +instance Table PersonT where + data PrimaryKey PersonT f = PersonKey (C f Text) + deriving stock (Generic) + deriving anyclass (Beamable) + + primaryKey Person{name} = PersonKey name + +data Db f = Db + { persons :: f (TableEntity PersonT) + } + deriving (Generic) + +instance Database Postgres Db + +db :: DatabaseSettings Postgres Db +db = defaultDbSettings + +windowingQueryTest :: + (Eq a, Show a, Eq b, Show b, FromBackendRow Postgres a, FromBackendRow Postgres b) => + Q Postgres Db QBaseScope (QExpr Postgres s a, QExpr Postgres s b) -> + [(a, b)] -> + IO ByteString -> + Assertion +windowingQueryTest query expectation getConnStr = + withTestPostgres "db_windowing_psql" getConnStr $ + \conn -> do + prepareTable conn + results <- + runBeamPostgres conn $ + runSelectReturningList $ + select query + + assertEqual "Unexpected" expectation results + +prepareTable :: Connection -> IO () +prepareTable conn = + runBeamPostgres conn $ do + void $ autoMigrate migrationBackend (defaultMigratableDbSettings @Postgres @Db) + runInsert $ + insert (persons db) $ + insertValues + [ Person "Alice" + , Person "Bob" + , Person "Claire" + ] diff --git a/beam-postgres/test/Main.hs b/beam-postgres/test/Main.hs index 239515ad7..464bbdc6b 100644 --- a/beam-postgres/test/Main.hs +++ b/beam-postgres/test/Main.hs @@ -11,18 +11,20 @@ import qualified Database.Beam.Postgres.Test.Select as Select import qualified Database.Beam.Postgres.Test.Marshal as Marshal import qualified Database.Beam.Postgres.Test.DataTypes as DataType import qualified Database.Beam.Postgres.Test.Migrate as Migrate +import qualified Database.Beam.Postgres.Test.Windowing as Windowing import Database.PostgreSQL.Simple ( ConnectInfo(..), defaultConnectInfo ) import qualified Database.PostgreSQL.Simple as Postgres main :: IO () -main = defaultMain - $ TC.withContainers setupTempPostgresDB - $ \getConnStr -> +main = defaultMain + $ TC.withContainers setupTempPostgresDB + $ \getConnStr -> testGroup "beam-postgres tests" [ Marshal.tests getConnStr , Select.tests getConnStr , DataType.tests getConnStr , Migrate.tests getConnStr + , Windowing.tests getConnStr ] @@ -39,10 +41,10 @@ setupTempPostgresDB = do , ("POSTGRES_DB", db) ] TC.& TC.setWaitingFor (TC.waitForLogLine TC.Stderr ("database system is ready to accept connections" `TL.isInfixOf`)) - - pure $ Postgres.postgreSQLConnectionString + + pure $ Postgres.postgreSQLConnectionString ( defaultConnectInfo { connectHost = "localhost" - , connectUser = unpack user + , connectUser = unpack user , connectPassword = unpack password , connectDatabase = unpack db , connectPort = fromIntegral $ TC.containerPort timescaleContainer 5432