1414
1515import abc
1616import typing
17+ import uuid
18+
19+ import pandas as pd
20+
21+ import bigframes .constants as constants
22+ import bigframes .dtypes as dtypes
1723
# Fallback for the ``sampling_n`` plotting keyword when the caller omits it.
DEFAULT_SAMPLING_N = 1000
# Fallback sampling random state.
# NOTE(review): presumably consumed by _compute_sample_data -- confirm there.
DEFAULT_SAMPLING_STATE = 0
@@ -44,12 +50,13 @@ def _kind(self):
4450
def __init__(self, data, **kwargs) -> None:
    """Store the plotting keyword arguments and the raw input data.

    No sampling or conversion happens here; the data is kept as given so
    that the plot data can be computed later, when the plot is generated.

    Args:
        data: The object to plot; stored unchanged on ``self.data``.
        **kwargs: Plotting options; stored as a dict on ``self.kwargs``.
    """
    self.kwargs = kwargs
    self.data = data
4854
def generate(self) -> None:
    """Compute the plot data and draw it, storing the result on ``self.axes``.

    The data comes from ``self._compute_plot_data()`` and is plotted with
    its ``plot`` method, using ``self._kind`` as the plot kind and
    forwarding whatever remains in ``self.kwargs``.
    """
    plot_data = self._compute_plot_data()
    self.axes = plot_data.plot(kind=self._kind, **self.kwargs)
5158
52- def _compute_plot_data (self , data ):
59+ def _compute_sample_data (self , data ):
5360 # TODO: Cache the sampling data in the PlotAccessor.
5461 sampling_n = self .kwargs .pop ("sampling_n" , DEFAULT_SAMPLING_N )
5562 sampling_random_state = self .kwargs .pop (
@@ -61,6 +68,9 @@ def _compute_plot_data(self, data):
6168 sort = False ,
6269 ).to_pandas ()
6370
def _compute_plot_data(self):
    """Return the data to plot: the sample computed from ``self.data``.

    This simply delegates to ``self._compute_sample_data``.
    """
    return self._compute_sample_data(self.data)
73+
6474
6575class LinePlot (SamplingPlot ):
6676 @property
@@ -78,3 +88,45 @@ class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
    """Plot kind identifier for scatter plots."""
    return "scatter"
91+
def __init__(self, data, **kwargs) -> None:
    """Initialise the scatter plot and validate the ``c`` (colour) option.

    Args:
        data: The object to plot; forwarded to the parent initialiser.
        **kwargs: Plotting options; forwarded to the parent initialiser.

    Raises:
        NotImplementedError: If ``c`` is a non-string iterable (as judged
            by ``_is_sequence_arg``), e.g. a list of colours.
    """
    super().__init__(data, **kwargs)

    c = self.kwargs.get("c", None)
    if self._is_sequence_arg(c):
        raise NotImplementedError(
            f"Only support a single color string or a column name/position. {constants.FEEDBACK_LINK}"
        )
100+
def _compute_plot_data(self):
    """Return the sampled data, with a string-typed ``c`` column cast to object.

    The sample is taken with ``self._compute_sample_data(self.data)``. If
    the ``c`` option refers to a column of the sample (by name, or by
    integer position) whose dtype equals ``dtypes.STRING_DTYPE``, that
    column is converted to ``object`` dtype in the sample before returning.
    """
    sample = self._compute_sample_data(self.data)

    # Works around a pandas bug:
    # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
    c = self.kwargs.get("c", None)
    if pd.api.types.is_integer(c):
        # NOTE(review): an integer ``c`` is always treated as a column
        # position here; confirm that is intended when the column labels
        # are themselves integers.
        c = self.data.columns[c]
    if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
        sample[c] = sample[c].astype("object")

    return sample
113+
def _is_sequence_arg(self, arg):
    """Return True when ``arg`` is an iterable other than a string.

    ``None`` and ``str`` values are never treated as sequences.
    """
    return (
        arg is not None
        and not isinstance(arg, str)
        and isinstance(arg, typing.Iterable)
    )
120+
def _is_column_name(self, arg, data):
    """Return True when ``arg`` is a hashable label present in ``data.columns``.

    ``None`` and unhashable values (e.g. lists) yield False without
    consulting the columns.
    """
    return (
        arg is not None
        and pd.api.types.is_hashable(arg)
        and arg in data.columns
    )
127+
def _generate_new_column_name(self, data):
    """Return a ``plot_temp_<8 chars>`` name that is not in ``data.columns``.

    The suffix is the first eight characters of a random UUID4 string;
    candidates are regenerated until one does not collide with an
    existing column.
    """
    while True:
        col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
        if col_name not in data.columns:
            return col_name
0 commit comments