|
17 | 17 | import datetime |
18 | 18 | import functools |
19 | 19 | import io |
20 | | -import itertools |
21 | 20 | import typing |
22 | | -from typing import Iterable, Optional, Sequence |
| 21 | +from typing import Iterable, Optional, Sequence, Tuple |
23 | 22 | import warnings |
24 | 23 |
|
25 | 24 | import google.cloud.bigquery |
@@ -191,19 +190,14 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue: |
191 | 190 | nodes.ConcatNode(children=tuple([self.node, *[val.node for val in other]])) |
192 | 191 | ) |
193 | 192 |
|
194 | | - def project_to_id(self, expression: ex.Expression, output_id: str): |
| 193 | + def compute_values(self, assignments: Sequence[Tuple[ex.Expression, str]]): |
195 | 194 | return ArrayValue( |
196 | | - nodes.ProjectionNode( |
197 | | - child=self.node, |
198 | | - assignments=( |
199 | | - ( |
200 | | - expression, |
201 | | - output_id, |
202 | | - ), |
203 | | - ), |
204 | | - ) |
| 195 | + nodes.ProjectionNode(child=self.node, assignments=tuple(assignments)) |
205 | 196 | ) |
206 | 197 |
|
| 198 | + def project_to_id(self, expression: ex.Expression, output_id: str): |
| 199 | + return self.compute_values(((expression, output_id),)) |
| 200 | + |
207 | 201 | def assign(self, source_id: str, destination_id: str) -> ArrayValue: |
208 | 202 | if destination_id in self.column_ids: # Mutate case |
209 | 203 | exprs = [ |
@@ -341,124 +335,33 @@ def _reproject_to_table(self) -> ArrayValue: |
341 | 335 | ) |
342 | 336 | ) |
343 | 337 |
|
344 | | - def unpivot( |
345 | | - self, |
346 | | - row_labels: typing.Sequence[typing.Hashable], |
347 | | - unpivot_columns: typing.Sequence[ |
348 | | - typing.Tuple[str, typing.Tuple[typing.Optional[str], ...]] |
349 | | - ], |
350 | | - *, |
351 | | - passthrough_columns: typing.Sequence[str] = (), |
352 | | - index_col_ids: typing.Sequence[str] = ["index"], |
353 | | - join_side: typing.Literal["left", "right"] = "left", |
354 | | - ) -> ArrayValue: |
355 | | - """ |
356 | | - Unpivot ArrayValue columns. |
357 | | -
|
358 | | - Args: |
359 | | - row_labels: Identifies the source of the row. Must be equal to length to source column list in unpivot_columns argument. |
360 | | - unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None. |
361 | | - passthrough_columns: Columns that will not be unpivoted. Column id will be preserved. |
362 | | - index_col_id (str): The column id to be used for the row labels. |
363 | | -
|
364 | | - Returns: |
365 | | - ArrayValue: The unpivoted ArrayValue |
366 | | - """ |
367 | | - # There will be N labels, used to disambiguate which of N source columns produced each output row |
368 | | - explode_offsets_id = bigframes.core.guid.generate_guid("unpivot_offsets_") |
369 | | - labels_array = self._create_unpivot_labels_array( |
370 | | - row_labels, index_col_ids, explode_offsets_id |
371 | | - ) |
372 | | - |
373 | | - # Unpivot creates N output rows for each input row, labels disambiguate these N rows |
374 | | - joined_array = self._cross_join_w_labels(labels_array, join_side) |
375 | | - |
376 | | -        # Build the output rows as a case statement that selects between the N input columns |
377 | | - unpivot_exprs = [] |
378 | | -        # Supports producing multiple stacked output columns for stacking only part of hierarchical index |
379 | | - for col_id, input_ids in unpivot_columns: |
380 | | - # row explode offset used to choose the input column |
381 | | - # we use offset instead of label as labels are not necessarily unique |
382 | | - cases = itertools.chain( |
383 | | - *( |
384 | | - ( |
385 | | - ops.eq_op.as_expr(explode_offsets_id, ex.const(i)), |
386 | | - ex.free_var(id_or_null) |
387 | | - if (id_or_null is not None) |
388 | | - else ex.const(None), |
389 | | - ) |
390 | | - for i, id_or_null in enumerate(input_ids) |
391 | | - ) |
392 | | - ) |
393 | | - col_expr = ops.case_when_op.as_expr(*cases) |
394 | | - unpivot_exprs.append((col_expr, col_id)) |
395 | | - |
396 | | - unpivot_col_ids = [id for id, _ in unpivot_columns] |
397 | | - return ArrayValue( |
398 | | - nodes.ProjectionNode( |
399 | | - child=joined_array.node, |
400 | | - assignments=(*unpivot_exprs,), |
401 | | - ) |
402 | | - ).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns]) |
403 | | - |
404 | | - def _cross_join_w_labels( |
405 | | - self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"] |
406 | | - ) -> ArrayValue: |
407 | | - """ |
408 | | - Convert each row in self to N rows, one for each label in labels array. |
409 | | - """ |
410 | | - table_join_side = ( |
411 | | - join_def.JoinSide.LEFT if join_side == "left" else join_def.JoinSide.RIGHT |
412 | | - ) |
413 | | - labels_join_side = table_join_side.inverse() |
414 | | - labels_mappings = tuple( |
415 | | - join_def.JoinColumnMapping(labels_join_side, id, id) |
416 | | - for id in labels_array.schema.names |
417 | | - ) |
418 | | - table_mappings = tuple( |
419 | | - join_def.JoinColumnMapping(table_join_side, id, id) |
420 | | - for id in self.schema.names |
421 | | - ) |
422 | | - join = join_def.JoinDefinition( |
423 | | - conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross" |
424 | | - ) |
425 | | - if join_side == "left": |
426 | | - joined_array = self.relational_join(labels_array, join_def=join) |
427 | | - else: |
428 | | - joined_array = labels_array.relational_join(self, join_def=join) |
429 | | - return joined_array |
430 | | - |
431 | | - def _create_unpivot_labels_array( |
432 | | - self, |
433 | | - former_column_labels: typing.Sequence[typing.Hashable], |
434 | | - col_ids: typing.Sequence[str], |
435 | | - offsets_id: str, |
436 | | - ) -> ArrayValue: |
437 | | - """Create an ArrayValue from a list of label tuples.""" |
438 | | - rows = [] |
439 | | - for row_offset in range(len(former_column_labels)): |
440 | | - row_label = former_column_labels[row_offset] |
441 | | - row_label = (row_label,) if not isinstance(row_label, tuple) else row_label |
442 | | - row = { |
443 | | - col_ids[i]: (row_label[i] if pandas.notnull(row_label[i]) else None) |
444 | | - for i in range(len(col_ids)) |
445 | | - } |
446 | | - row[offsets_id] = row_offset |
447 | | - rows.append(row) |
448 | | - |
449 | | - return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session) |
450 | | - |
451 | 338 | def relational_join( |
452 | 339 | self, |
453 | 340 | other: ArrayValue, |
454 | | - join_def: join_def.JoinDefinition, |
455 | | - ) -> ArrayValue: |
| 341 | + conditions: typing.Tuple[typing.Tuple[str, str], ...] = (), |
| 342 | + type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner", |
| 343 | + ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]: |
456 | 344 | join_node = nodes.JoinNode( |
457 | 345 | left_child=self.node, |
458 | 346 | right_child=other.node, |
459 | | - join=join_def, |
| 347 | + conditions=conditions, |
| 348 | + type=type, |
460 | 349 | ) |
461 | | - return ArrayValue(join_node) |
| 350 | + # Maps input ids to output ids for caller convenience |
| 351 | + l_size = len(self.node.schema) |
| 352 | + l_mapping = { |
| 353 | + lcol: ocol |
| 354 | + for lcol, ocol in zip( |
| 355 | + self.node.schema.names, join_node.schema.names[:l_size] |
| 356 | + ) |
| 357 | + } |
| 358 | + r_mapping = { |
| 359 | + rcol: ocol |
| 360 | + for rcol, ocol in zip( |
| 361 | + other.node.schema.names, join_node.schema.names[l_size:] |
| 362 | + ) |
| 363 | + } |
| 364 | + return ArrayValue(join_node), (l_mapping, r_mapping) |
462 | 365 |
|
463 | 366 | def try_align_as_projection( |
464 | 367 | self, |
|
0 commit comments