@@ -833,35 +833,46 @@ def _materialize_local(
833833 return df , execute_result .query_job
834834
def _downsample(
    self,
    total_rows: int,
    sampling_method: str,
    fraction: float,
    random_state: Optional[int],
) -> Block:
    """Return a downsampled version of this block.

    Args:
        total_rows: Row count used to convert ``fraction`` into a row limit
            for head sampling.
        sampling_method: Either ``_HEAD`` or ``_UNIFORM``.
        fraction: Share of the rows to keep.
        random_state: Optional seed for uniform sampling. It is forwarded
            unchanged to ``self.sample``; ``None`` means no seed is given.

    Returns:
        The downsampled block.

    Raises:
        NotImplementedError: If ``sampling_method`` is not one of the
            supported methods.
    """
    if sampling_method == _HEAD:
        # Head sampling keeps the rows before position
        # int(total_rows * fraction).
        return self.slice(stop=int(total_rows * fraction))

    if sampling_method == _UNIFORM:
        # Always forward the seed. Splitting on ``random_state is None`` and
        # omitting the seed in the non-None branch silently discarded a
        # caller-supplied random_state.
        return self.sample(fraction=fraction, shuffle=False, seed=random_state)

    # This part should never be called, just in case.
    raise NotImplementedError(
        f"The downsampling method {sampling_method} is not implemented, "
        f"please choose from {','.join(_SAMPLING_METHODS)}."
    )
864856
def sample(
    self, fraction: float, shuffle: bool, seed: Optional[int] = None
) -> Block:
    """Build a new block from a uniform sample of this block's expression.

    Args:
        fraction: Share of rows to request; must lie in ``[0, 1]``.
        shuffle: Forwarded to the expression's ``_uniform_sampling``.
        seed: Optional seed forwarded to ``_uniform_sampling``.

    Returns:
        A block wrapping the sampled expression, with this block's index
        columns, column labels and index names.

    Raises:
        ValueError: If ``fraction`` is outside ``[0, 1]`` (NaN included).
    """
    # Explicit check instead of ``assert`` so the validation is still
    # enforced when Python runs with -O, which strips assert statements.
    if not (0 <= fraction <= 1.0):
        raise ValueError(
            f"fraction must be between 0 and 1 inclusive, got {fraction!r}"
        )
    return Block(
        self.expr._uniform_sampling(fraction=fraction, shuffle=shuffle, seed=seed),
        index_columns=self.index_columns,
        column_labels=self.column_labels,
        index_labels=self.index.names,
    )
867+
def shuffle(self, seed: Optional[int] = None) -> Block:
    """Build a new block from this block's expression with shuffling enabled.

    Args:
        seed: Optional seed forwarded to the expression's
            ``_uniform_sampling``.

    Returns:
        A block wrapping the resulting expression, with this block's index
        columns, column labels and index names.
    """
    # fraction=1.0 together with shuffle=True: request the full fraction and
    # ask the expression to shuffle.
    shuffled_expr = self.expr._uniform_sampling(
        fraction=1.0, shuffle=True, seed=seed
    )
    return Block(
        shuffled_expr,
        index_columns=self.index_columns,
        column_labels=self.column_labels,
        index_labels=self.index.names,
    )
875+
865876 def split (
866877 self ,
867878 ns : Iterable [int ] = (),
@@ -894,22 +905,11 @@ def split(
894905 random_state = random .randint (- (2 ** 63 ), 2 ** 63 - 1 )
895906
896907 # Create a new column with random_state value.
897- block , random_state_col = block .create_constant (str (random_state ))
908+ og_ordering_col = None
909+ if sort is False :
910+ block , og_ordering_col = block .promote_offsets ()
898911
899- # Create an ordering col and convert to string
900- block , ordering_col = block .promote_offsets ()
901- block , string_ordering_col = block .apply_unary_op (
902- ordering_col , ops .AsTypeOp (to_type = bigframes .dtypes .STRING_DTYPE )
903- )
904-
905- # Apply hash method to sum col and order by it.
906- block , string_sum_col = block .apply_binary_op (
907- string_ordering_col , random_state_col , ops .strconcat_op
908- )
909- block , hash_string_sum_col = block .apply_unary_op (string_sum_col , ops .hash_op )
910- block = block .order_by (
911- [ordering .OrderingExpression (ex .deref (hash_string_sum_col ))]
912- )
912+ block = block .shuffle (seed = random_state )
913913
914914 intervals = []
915915 cur = 0
@@ -934,21 +934,15 @@ def split(
934934 for sliced_block in sliced_blocks
935935 ]
936936 elif sort is False :
937+ assert og_ordering_col is not None
937938 sliced_blocks = [
938939 sliced_block .order_by (
939- [ordering .OrderingExpression (ex .deref (ordering_col ))]
940- )
940+ [ordering .OrderingExpression (ex .deref (og_ordering_col ))]
941+ ). drop_columns ([ og_ordering_col ])
941942 for sliced_block in sliced_blocks
942943 ]
943944
944- drop_cols = [
945- random_state_col ,
946- ordering_col ,
947- string_ordering_col ,
948- string_sum_col ,
949- hash_string_sum_col ,
950- ]
951- return [sliced_block .drop_columns (drop_cols ) for sliced_block in sliced_blocks ]
945+ return [sliced_block for sliced_block in sliced_blocks ]
952946
953947 def _compute_dry_run (
954948 self ,
0 commit comments