88from matplotlib import gridspec
99from matplotlib import pyplot as plt
1010
11+ from .exceptions import ShapeInconsistencyError
12+
1113mpl .rcParams ['pdf.fonttype' ] = 42
1214mpl .rcParams ["font.sans-serif" ] = "Arial"
1315
@@ -40,15 +42,15 @@ def __init__(self, df_size: pd.DataFrame,
4042 'figure'
4143 ]
4244 if df_color is not None and df_size .shape != df_color .shape :
43- raise ValueError ('df_size and df_color should have the same dimension' )
45+ raise ShapeInconsistencyError ('df_size and df_color should have the same dimension' )
4446 if df_circle is not None and df_size .shape != df_circle .shape :
45- raise ValueError ('df_size and df_circle should have the same dimension' )
47+ raise ShapeInconsistencyError ('df_size and df_circle should have the same dimension' )
4648 if row_colors is not None and df_size .shape [0 ] != len (row_colors ):
47- raise ValueError ('row_colors has the wrong shape' )
49+ raise ShapeInconsistencyError ('row_colors has the wrong shape' )
4850 if col_colors is not None and df_size .shape [1 ] != len (col_colors ):
49- raise ValueError ('col_colors has the wrong shape' )
51+ raise ShapeInconsistencyError ('col_colors has the wrong shape' )
5052 if mask_frames is not None and df_size .shape != mask_frames .shape :
51- raise ValueError ('df_size and mask_frames should have the same dimension' )
53+ raise ShapeInconsistencyError ('df_size and mask_frames should have the same dimension' )
5254
5355 self .size_data = df_size
5456 self .color_data = df_color
@@ -108,13 +110,15 @@ def __get_figure(self):
108110 ax_abandon .axis ('off' )
109111 return ax , gs_cbar_legend , gs_sizes_legend , gs_circles_legend , ax_row_bands , ax_col_bands , fig
110112
111- # TODO update with the newest version of __init__
112113 @classmethod
113114 def parse_from_tidy_data (cls , data_frame : pd .DataFrame , item_key : str , group_key : str , sizes_key : str ,
114115 color_key : Union [None , str ] = None , circle_key : Union [None , str ] = None ,
115116 selected_item : Union [None , Sequence ] = None ,
116- selected_group : Union [None , Sequence ] = None , * ,
117- sizes_func : Union [None , Callable ] = None , color_func : Union [None , Callable ] = None
117+ selected_group : Union [None , Sequence ] = None ,
118+ row_colors : Union [None , pd .Series , pd .DataFrame ] = None ,
119+ col_colors : Union [None , pd .Series , pd .DataFrame ] = None ,
120+ mask_frames : Union [None , pd .DataFrame , Sequence [Union [str , int ]]] = None ,
121+ * , sizes_func : Union [None , Callable ] = None , color_func : Union [None , Callable ] = None
118122 ):
119123 """
120124
@@ -125,12 +129,15 @@ class method for conveniently constructing DotPlot from tidy data
125129 :param group_key:
126130 :param sizes_key:
127131 :param color_key:
132+ :param circle_key:
128133 :param selected_item: default None; if specified, this should be a subset of the `item_key` values in `data_frame`.
129134 Alternatively, this param can be used to define a custom item order.
130135 :param selected_group: same as `selected_item`, but for subsetting and ordering groups.
131- :param sizes_func:
132- :param color_func:
133- :param circle_key:
136+ :param col_colors:
137+ :param row_colors:
138+ :param mask_frames:
139+ :param sizes_func: Callable
140+ :param color_func: Callable
134141 :return:
135142 """
136143 keys = [v for v in [item_key , group_key , sizes_key , color_key , circle_key ] if v is not None ]
@@ -159,7 +166,20 @@ class method for conveniently constructing DotPlot from tidy data
159166 color_df = data_frame .loc [:, data_frame .columns .str .startswith (color_key )]
160167 if circle_key is not None :
161168 circle_df = data_frame .loc [:, data_frame .columns .str .startswith (circle_key )]
162- return cls (sizes_df , color_df , circle_df )
169+ if (mask_frames is not None ) and isinstance (mask_frames , Sequence ):
170+ mask_frames = mask_frames if isinstance (mask_frames , List ) else list (mask_frames )
171+ n_row , n_col = sizes_df .shape
172+ if len (mask_frames ) == n_row :
173+ mask_frames = pd .DataFrame ([[item ] * n_col for item in mask_frames ],
174+ index = sizes_df .index .values , columns = sizes_df .columns .values )
175+ elif len (mask_frames ) == n_col :
176+ mask_frames = pd .DataFrame ([[item ] * n_row for item in mask_frames ],
177+ index = sizes_df .columns .values , columns = sizes_df .index .values )
178+ mask_frames = mask_frames .T
179+ else :
180+ raise ShapeInconsistencyError ('mask frame shape Error.' )
181+ return cls (sizes_df , color_df , circle_df , row_colors = row_colors ,
182+ col_colors = col_colors , mask_frames = mask_frames )
163183
164184 def __get_coordinates (self ):
165185 X = list (range (1 , self .width_item + 1 )) * self .height_item
0 commit comments