|
23 | 23 | import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler |
24 | 24 |
|
# Short aliases for the registration hooks on the shared scalar op compiler.
# They are applied below as decorators on the per-op compile functions
# (e.g. ``@register_unary_op(ops.hour_op)``).
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
26 | 27 |
|
27 | 28 |
|
28 | 29 | @register_unary_op(ops.FloorDtOp, pass_op=True) |
@@ -51,6 +52,28 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression: |
51 | 52 | return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq)) |
52 | 53 |
|
53 | 54 |
|
def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resample origin, in microseconds.

    Args:
        y: Typed expression the origin is derived from. It is not read when
            ``origin`` is ``"epoch"``.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        The literal ``0`` for ``"epoch"``. Otherwise a ``UNIX_MICROS`` call
        over ``y`` cast to TIMESTAMPTZ; for ``"start_day"`` ``y`` is cast to
        DATE first.

    Raises:
        ValueError: If ``origin`` is any other value.
    """
    if origin == "epoch":
        return sge.convert(0)
    if origin not in ("start_day", "start"):
        raise ValueError(f"Origin {origin} not supported")

    anchor = y.expr
    if origin == "start_day":
        # Route through DATE before converting to a timestamp.
        anchor = sge.Cast(this=anchor, to=sge.DataType(this=sge.DataType.Type.DATE))
    as_timestamp = sge.Cast(
        this=anchor, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    )
    return sge.func("UNIX_MICROS", as_timestamp)
| 75 | + |
| 76 | + |
@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile the hour op into an ``Extract`` node for the HOUR part."""
    hour_part = sge.Identifier(this="HOUR")
    return sge.Extract(this=hour_part, expression=expr.expr)
@@ -170,3 +193,243 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression: |
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile the year op into an ``Extract`` node for the YEAR part."""
    year_part = sge.Identifier(this="YEAR")
    return sge.Extract(this=year_part, expression=expr.expr)
| 196 | + |
| 197 | + |
def _dtype_to_sql_string(dtype: dtypes.Dtype) -> str:
    """Return the SQL type name for a date/time dtype.

    Args:
        dtype: Dtype to translate; compared with ``==`` against the
            timestamp, datetime, date and time dtypes, in that order.

    Returns:
        ``"TIMESTAMP"``, ``"DATETIME"``, ``"DATE"`` or ``"TIME"``.

    Raises:
        ValueError: If ``dtype`` equals none of the four dtypes.
    """
    sql_names = (
        (dtypes.TIMESTAMP_DTYPE, "TIMESTAMP"),
        (dtypes.DATETIME_DTYPE, "DATETIME"),
        (dtypes.DATE_DTYPE, "DATE"),
        (dtypes.TIME_DTYPE, "TIME"),
    )
    for known_dtype, sql_name in sql_names:
        if dtype == known_dtype:
            return sql_name
    # Should not be reached in this operator
    raise ValueError(f"Unsupported dtype for datetime conversion: {dtype}")
| 209 | + |
| 210 | + |
@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile ``IntegerLabelToDatetimeOp``: integer bin labels to datetimes.

    Dispatches on whether ``op.freq`` is a fixed frequency, i.e. whether
    reading ``op.freq.nanos`` succeeds.

    Args:
        x: Integer label expression.
        y: Datetime-like expression the labels are anchored to.
        op: The op being compiled; ``op.freq`` selects the code path.

    Returns:
        The expression built by the fixed- or non-fixed-frequency helper.

    Raises:
        ValueError: Propagated unchanged from whichever helper runs
            (unsupported origin, dtype or rule code).
    """
    try:
        # Probe only the attribute access. Keeping the fixed-frequency
        # compile outside the ``try`` means its own ValueErrors (unsupported
        # origin or dtype) surface with their real message instead of being
        # swallowed and replaced by a bare ``ValueError(rule_code)`` from the
        # non-fixed path.
        # NOTE(review): assumes ``op.freq`` is a pandas offset, whose
        # ``nanos`` raises ValueError for non-fixed frequencies -- confirm.
        _ = op.freq.nanos
    except ValueError:
        return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op)
    return _integer_label_to_datetime_op_fixed_frequency(x, y, op)
| 220 | + |
| 221 | + |
def _integer_label_to_datetime_op_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Convert integer bin labels to datetimes for a fixed frequency.

    Handles frequencies whose unit ranges from microseconds (us) to days.
    The label is ``TIMESTAMP_MICROS(x * step + origin)``, where ``step`` is
    ``op.freq.nanos / 1000`` truncated to an integer and ``origin`` comes from
    ``_calculate_resample_first``. The arithmetic is done in BIGNUMERIC and
    cast to INT64 before the conversion.

    Args:
        x: Integer label expression.
        y: Datetime-like expression; supplies the origin and the result type.
        op: The op being compiled; ``op.freq`` and ``op.origin`` are read.

    Returns:
        The label expression cast to the SQL type matching ``y.dtype``.
    """
    step_micros = int(op.freq.nanos / 1000)
    origin_micros = _calculate_resample_first(y, op.origin)  # type: ignore

    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to=sge.DataType.build("BIGNUMERIC")),
        expression=sge.convert(step_micros),
    )
    total_micros = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_micros, to=sge.DataType.build("BIGNUMERIC")),
    )
    as_int64 = sge.Cast(this=total_micros, to=sge.DataType.build("INT64"))
    as_timestamp = sge.func("TIMESTAMP_MICROS", as_int64)
    return sge.Cast(
        this=as_timestamp,
        to=_dtype_to_sql_string(y.dtype),  # type: ignore
    )
| 250 | + |
| 251 | + |
def _integer_label_to_datetime_op_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Convert integer bin labels to datetimes for a non-fixed frequency.

    Handles units ranging from weeks to years. The supported
    ``op.freq.rule_code`` values are ``"W-SUN"`` (weekly), ``"ME"`` (monthly),
    ``"QE-DEC"`` (quarterly) and ``"YE-DEC"`` (yearly).

    Args:
        x: Integer label expression; it is multiplied by ``op.freq.n``.
        y: Datetime-like expression the labels are anchored to. Its dtype
            selects the SQL type of the result.
        op: The op being compiled; only ``op.freq`` is read here.

    Returns:
        The label expression cast to the SQL type matching ``y.dtype``.

    Raises:
        ValueError: If the rule code is not one of the four listed above.
    """
    rule_code = op.freq.rule_code
    n = op.freq.n
    if rule_code not in ("W-SUN", "ME", "QE-DEC", "YE-DEC"):
        raise ValueError(rule_code)

    def split_ordinal(part: str, periods_per_year, one):
        """Return (period ordinal of the label, its year) for month/quarter."""
        # Ordinal of y's own period: YEAR * periods_per_year + part - 1.
        origin_ordinal = sge.Sub(
            this=sge.Add(
                this=sge.Mul(
                    this=sge.Extract(this="YEAR", expression=y.expr),
                    expression=periods_per_year,
                ),
                expression=sge.Extract(this=part, expression=y.expr),
            ),
            expression=one,
        )
        ordinal = sge.Add(
            this=sge.Mul(this=x.expr, expression=sge.convert(n)),
            expression=origin_ordinal,
        )
        year = sge.Cast(
            this=sge.Floor(this=sge.func("IEEE_DIVIDE", ordinal, periods_per_year)),
            to=sge.DataType.build("INT64"),
        )
        return ordinal, year

    def roll_over(year, month, one, twelve):
        """Return (year, month) of the month following ``month``."""
        following_year = sge.Case(
            ifs=[
                sge.If(
                    this=sge.EQ(this=month, expression=twelve),
                    true=sge.Add(this=year, expression=one),
                )
            ],
            default=year,
        )
        following_month = sge.Case(
            ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
            default=sge.Add(this=month, expression=one),
        )
        return following_year, following_month

    def midnight_on(year, month, day):
        """Build ``DATETIME(year, month, day, 0, 0, 0)``."""
        parts = [year, month, day, sge.convert(0), sge.convert(0), sge.convert(0)]
        return sge.Anonymous(this="DATETIME", expressions=parts)

    def minus_one_day(moment, one):
        """Step back one day, from the start of a period to the end of the previous one."""
        return sge.Sub(this=moment, expression=sge.Interval(this=one, unit="DAY"))

    label: sge.Expression
    if rule_code == "W-SUN":  # Weekly
        week_micros = n * 7 * 24 * 60 * 60 * 1000000
        # Origin: y truncated to WEEK(MONDAY), then moved forward six days.
        week_start = sge.TimestampTrunc(
            this=sge.Cast(
                this=y.expr,
                to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
            ),
            unit=sge.Var(this="WEEK(MONDAY)"),
        )
        six_days = sge.Interval(this=sge.convert(6), unit=sge.Identifier(this="DAY"))
        origin_micros = sge.func(
            "UNIX_MICROS", sge.Add(this=week_start, expression=six_days)
        )
        scaled_label = sge.Mul(
            this=sge.Cast(this=x.expr, to=sge.DataType.build("BIGNUMERIC")),
            expression=sge.convert(week_micros),
        )
        total_micros = sge.Add(
            this=scaled_label,
            expression=sge.Cast(
                this=origin_micros, to=sge.DataType.build("BIGNUMERIC")
            ),
        )
        as_timestamp = sge.func(
            "TIMESTAMP_MICROS",
            sge.Cast(this=total_micros, to=sge.DataType.build("INT64")),
        )
        label = sge.Cast(
            this=as_timestamp,
            to=_dtype_to_sql_string(y.dtype),  # type: ignore
        )
    elif rule_code == "ME":  # Monthly
        one = sge.convert(1)
        twelve = sge.convert(12)
        ordinal, year = split_ordinal("MONTH", twelve, one)
        month = sge.Add(this=sge.Mod(this=ordinal, expression=twelve), expression=one)
        next_year, next_month = roll_over(year, month, one, twelve)
        first_of_next = sge.func(
            "TIMESTAMP", midnight_on(next_year, next_month, one)
        )
        label = minus_one_day(first_of_next, one)
    elif rule_code == "QE-DEC":  # Quarterly
        one = sge.convert(1)
        three = sge.convert(3)
        four = sge.convert(4)
        twelve = sge.convert(12)
        ordinal, year = split_ordinal("QUARTER", four, one)
        # Last month of the quarter: (ordinal MOD 4 + 1) * 3.
        month = sge.Mul(
            this=sge.Paren(
                this=sge.Add(this=sge.Mod(this=ordinal, expression=four), expression=one)
            ),
            expression=three,
        )
        next_year, next_month = roll_over(year, month, one, twelve)
        # NOTE(review): unlike the monthly and yearly branches, the DATETIME
        # value is not wrapped in TIMESTAMP(...) here. Kept as in the
        # original -- confirm whether the difference is intentional.
        first_of_next = midnight_on(next_year, next_month, one)
        label = minus_one_day(first_of_next, one)
    else:  # "YE-DEC": Yearly
        one = sge.convert(1)
        origin_year = sge.Extract(this="YEAR", expression=y.expr)
        label_year = sge.Add(
            this=sge.Mul(this=x.expr, expression=sge.convert(n)),
            expression=origin_year,
        )
        next_year = sge.Add(this=label_year, expression=one)
        first_of_next = sge.func("TIMESTAMP", midnight_on(next_year, one, one))
        label = minus_one_day(first_of_next, one)
    return sge.Cast(this=label, to=_dtype_to_sql_string(y.dtype))  # type: ignore
0 commit comments