diff --git a/CHANGELOG.md b/CHANGELOG.md index dda77c01a..c1fce0e92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,12 +57,16 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp ### 💥 Breaking Changes ### ♻️ Changed +- **Template integration**: Plotly templates now fully control plot styling without hardcoded overrides +- **Dataset first plotting**: Underlying plotting methods in plotting.py now use `xr.Dataset` as the main datatype, converting to it if they get a DataFrame passed ### 🗑️ Deprecated ### 🔥 Removed ### 🐛 Fixed +- Improved error messages for matplotlib with multidimensional data +- Better dimension validation in `plot_heatmap()` ### 🔒 Security diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py index 5020f71fe..56251fd99 100644 --- a/examples/02_Complex/complex_example_results.py +++ b/examples/02_Complex/complex_example_results.py @@ -25,8 +25,9 @@ # --- Detailed Plots --- # In depth plot for individual flow rates ('__' is used as the delimiter between Component and Flow results.plot_heatmap('Wärmelast(Q_th_Last)|flow_rate') - for flow_rate in results['BHKW2'].inputs + results['BHKW2'].outputs: - results.plot_heatmap(flow_rate) + for bus in results.buses.values(): + bus.plot_node_balance_pie() + bus.plot_node_balance() # --- Plotting internal variables manually --- results.plot_heatmap('BHKW2(Q_th)|on') diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py index 8df8e742f..f71a8eed7 100644 --- a/examples/03_Calculation_types/example_calculation_types.py +++ b/examples/03_Calculation_types/example_calculation_types.py @@ -202,35 +202,35 @@ def get_solutions(calcs: list, variable: str) -> xr.Dataset: # --- Plotting for comparison --- fx.plotting.with_plotly( - get_solutions(calculations, 'Speicher|charge_state').to_dataframe(), + get_solutions(calculations, 'Speicher|charge_state'), 
mode='line', title='Charge State Comparison', ylabel='Charge state', ).write_html('results/Charge State.html') fx.plotting.with_plotly( - get_solutions(calculations, 'BHKW2(Q_th)|flow_rate').to_dataframe(), + get_solutions(calculations, 'BHKW2(Q_th)|flow_rate'), mode='line', title='BHKW2(Q_th) Flow Rate Comparison', ylabel='Flow rate', ).write_html('results/BHKW2 Thermal Power.html') fx.plotting.with_plotly( - get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe(), + get_solutions(calculations, 'costs(temporal)|per_timestep'), mode='line', title='Operation Cost Comparison', ylabel='Costs [€]', ).write_html('results/Operation Costs.html') fx.plotting.with_plotly( - pd.DataFrame(get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe().sum()).T, + get_solutions(calculations, 'costs(temporal)|per_timestep').sum('time'), mode='stacked_bar', title='Total Cost Comparison', ylabel='Costs [€]', ).update_layout(barmode='group').write_html('results/Total Costs.html') fx.plotting.with_plotly( - pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]), + pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]).to_xarray(), mode='stacked_bar', ).update_layout(title='Duration Comparison', xaxis_title='Calculation type', yaxis_title='Time (s)').write_html( 'results/Speed Comparison.html' diff --git a/flixopt/aggregation.py b/flixopt/aggregation.py index 91ef618a9..cd22c2ad7 100644 --- a/flixopt/aggregation.py +++ b/flixopt/aggregation.py @@ -150,10 +150,12 @@ def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path df_agg = self.aggregated_data.copy().rename( columns={col: f'Aggregated - {col}' for col in self.aggregated_data.columns} ) - fig = plotting.with_plotly(df_org, 'line', colors=colormap) + fig = plotting.with_plotly(df_org.to_xarray(), 'line', colors=colormap, xlabel='Time in h') for trace in fig.data: 
trace.update(dict(line=dict(dash='dash'))) - fig = plotting.with_plotly(df_agg, 'line', colors=colormap, fig=fig) + fig2 = plotting.with_plotly(df_agg.to_xarray(), 'line', colors=colormap, xlabel='Time in h') + for trace in fig2.data: + fig.add_trace(trace) fig.update_layout( title='Original vs Aggregated Data (original = ---)', xaxis_title='Index', yaxis_title='Value' diff --git a/flixopt/plotting.py b/flixopt/plotting.py index bd1f3c2c4..6fe5de0e4 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -42,6 +42,8 @@ import xarray as xr from plotly.exceptions import PlotlyError +from .config import CONFIG + if TYPE_CHECKING: import pyvis @@ -326,19 +328,84 @@ def process_colors( return color_list +def _ensure_dataset(data: xr.Dataset | pd.DataFrame) -> xr.Dataset: + """Convert DataFrame to Dataset if needed.""" + if isinstance(data, xr.Dataset): + return data + elif isinstance(data, pd.DataFrame): + # Convert DataFrame to Dataset + return data.to_xarray() + else: + raise TypeError(f'Data must be xr.Dataset or pd.DataFrame, got {type(data).__name__}') + + +def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None: + """Validate dataset for plotting (checks for empty data, non-numeric types, etc.).""" + # Check for empty data + if not allow_empty and len(data.data_vars) == 0: + raise ValueError('Empty Dataset provided (no variables). Cannot create plot.') + + # Check if dataset has any data (xarray uses nbytes for total size) + if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True: + if not allow_empty and len(data.data_vars) > 0: + raise ValueError('Dataset has zero size. Cannot create plot.') + if len(data.data_vars) == 0: + return # Empty dataset, nothing to validate + return + + # Check for non-numeric data types + for var in data.data_vars: + dtype = data[var].dtype + if not np.issubdtype(dtype, np.number): + raise TypeError( + f"Variable '{var}' has non-numeric dtype '{dtype}'. 
" + f'Plotting requires numeric data types (int, float, etc.).' + ) + + # Warn about NaN/Inf values + for var in data.data_vars: + if data[var].isnull().any(): + logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.") + if np.isinf(data[var].values).any(): + logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.") + + +def resolve_colors( + data: xr.Dataset, + colors: ColorType, + engine: PlottingEngine = 'plotly', +) -> dict[str, str]: + """Resolve colors parameter to a dict mapping variable names to colors.""" + # Get variable names from Dataset (always strings and unique) + labels = list(data.data_vars.keys()) + + # If explicit dict provided, use it directly + if isinstance(colors, dict): + return colors + + # If string or list, use ColorProcessor (traditional behavior) + if isinstance(colors, (str, list)): + processor = ColorProcessor(engine=engine) + return processor.process_colors(colors, labels, return_mapping=True) + + raise TypeError(f'Wrong type passed to resolve_colors(): {type(colors)}') + + def with_plotly( - data: pd.DataFrame | xr.DataArray | xr.Dataset, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', colors: ColorType = 'viridis', title: str = '', ylabel: str = '', - xlabel: str = 'Time in h', - fig: go.Figure | None = None, + xlabel: str = '', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, shared_yaxes: bool = True, shared_xaxes: bool = True, + trace_kwargs: dict[str, Any] | None = None, + layout_kwargs: dict[str, Any] | None = None, + **px_kwargs: Any, ) -> go.Figure: """ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. @@ -347,7 +414,7 @@ def with_plotly( For simple plots without faceting, can optionally add to an existing figure. Args: - data: A DataFrame or xarray DataArray/Dataset to plot. 
+ data: An xarray Dataset to plot. A pandas DataFrame is also accepted and is converted to a Dataset first. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. colors: Color specification (colormap, list, or dict mapping labels to colors). @@ -364,6 +431,12 @@ def with_plotly( facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). shared_yaxes: Whether subplots share y-axes. shared_xaxes: Whether subplots share x-axes. + trace_kwargs: Optional dict of parameters to pass to fig.update_traces(). + Use this to customize trace properties (e.g., marker style, line width). + layout_kwargs: Optional dict of parameters to pass to fig.update_layout(). + Use this to customize layout properties (e.g., width, height, legend position). + **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function + (px.bar, px.line, px.area). These override default arguments if provided. Returns: A Plotly figure object containing the faceted/animated plot. 
@@ -372,85 +445,85 @@ def with_plotly( Simple plot: ```python - fig = with_plotly(df, mode='area', title='Energy Mix') + fig = with_plotly(dataset, mode='area', title='Energy Mix') ``` Facet by scenario: ```python - fig = with_plotly(ds, facet_by='scenario', facet_cols=2) + fig = with_plotly(dataset, facet_by='scenario', facet_cols=2) ``` Animate by period: ```python - fig = with_plotly(ds, animate_by='period') + fig = with_plotly(dataset, animate_by='period') ``` Facet and animate: ```python - fig = with_plotly(ds, facet_by='scenario', animate_by='period') + fig = with_plotly(dataset, facet_by='scenario', animate_by='period') ``` """ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + # Handle empty data - if isinstance(data, pd.DataFrame) and data.empty: - return go.Figure() - elif isinstance(data, xr.DataArray) and data.size == 0: - return go.Figure() - elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: + if len(data.data_vars) == 0: + logger.error('"with_plotly() got an empty Dataset.') return go.Figure() - # Warn if fig parameter is used with faceting - if fig is not None and (facet_by is not None or animate_by is not None): - logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') - fig = None - - # Convert xarray to long-form DataFrame for Plotly Express - if isinstance(data, (xr.DataArray, xr.Dataset)): - # Convert to long-form (tidy) DataFrame - # Structure: time, variable, value, scenario, period, ... 
(all dims as columns) - if isinstance(data, xr.Dataset): - # Stack all data variables into long format - df_long = data.to_dataframe().reset_index() - # Melt to get: time, scenario, period, ..., variable, value - id_vars = [dim for dim in data.dims] - value_vars = list(data.data_vars) - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') - else: - # DataArray - df_long = data.to_dataframe().reset_index() - if data.name: - df_long = df_long.rename(columns={data.name: 'value'}) - else: - # Unnamed DataArray, find the value column - non_dim_cols = [col for col in df_long.columns if col not in data.dims] - if len(non_dim_cols) != 1: - raise ValueError( - f'Expected exactly one non-dimension column for unnamed DataArray, ' - f'but found {len(non_dim_cols)}: {non_dim_cols}' + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create a simple DataFrame with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='plotly') + marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables] + + # Create simple plot based on mode using go (not px) for better color control + if mode in ('stacked_bar', 'grouped_bar'): + fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)]) + elif mode == 'line': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + mode='lines+markers', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), ) - value_col = non_dim_cols[0] - df_long = df_long.rename(columns={value_col: 'value'}) - df_long['variable'] = data.name or 'data' - else: - # Already a DataFrame - convert to long format for Plotly Express - df_long = 
data.reset_index() - if 'time' not in df_long.columns: - # First column is probably time - df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) - # Melt to long format - id_vars = [ - col - for col in df_long.columns - if col in ['time', 'scenario', 'period'] - or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) - ] - value_vars = [col for col in df_long.columns if col not in id_vars] - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + ] + ) + elif mode == 'area': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + fill='tozeroy', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + + fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False) + return fig + + # Convert Dataset to long-form DataFrame for Plotly Express + # Structure: time, variable, value, scenario, period, ... (all dims as columns) + dim_names = list(data.dims) + df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value') # Validate facet_by and animate_by dimensions exist in the data available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] @@ -505,10 +578,32 @@ def with_plotly( processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + # Determine which dimension to use for x-axis + # Collect dimensions used for faceting and animation + used_dims = set() + if facet_row: + used_dims.add(facet_row) + if facet_col: + used_dims.add(facet_col) + if animate_by: + used_dims.add(animate_by) + + # Find available dimensions for x-axis (not used for faceting/animation) + x_candidates = [d for d in available_dims if d not in used_dims] + + # Use 'time' if available, otherwise use the first available dimension + if 'time' in 
x_candidates: + x_dim = 'time' + elif len(x_candidates) > 0: + x_dim = x_candidates[0] + else: + # Fallback: use the first dimension (shouldn't happen in normal cases) + x_dim = available_dims[0] if available_dims else 'time' + # Create plot using Plotly Express based on mode common_args = { 'data_frame': df_long, - 'x': 'time', + 'x': x_dim, 'y': 'value', 'color': 'variable', 'facet_row': facet_row, @@ -516,7 +611,7 @@ def with_plotly( 'animation_frame': animate_by, 'color_discrete_map': color_discrete_map, 'title': title, - 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, + 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''}, } # Add facet_col_wrap for single facet dimension @@ -577,50 +672,48 @@ def with_plotly( if hasattr(trace, 'fill'): trace.fill = None - # Update layout with basic styling (Plotly Express handles sizing automatically) - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) - # Update axes to share if requested (Plotly Express already handles this, but we can customize) if not shared_yaxes: fig.update_yaxes(matches=None) if not shared_xaxes: fig.update_xaxes(matches=None) + # Apply user-provided trace and layout customizations + if trace_kwargs: + fig.update_traces(**trace_kwargs) + if layout_kwargs: + fig.update_layout(**layout_kwargs) + return fig def with_matplotlib( - data: pd.DataFrame, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line'] = 'stacked_bar', colors: ColorType = 'viridis', title: str = '', ylabel: str = '', xlabel: str = 'Time in h', figsize: tuple[int, int] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, + plot_kwargs: dict[str, Any] | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Plot a DataFrame with Matplotlib using stacked bars or stepped lines. + Plot data with Matplotlib using stacked bars or stepped lines. Args: - data: A DataFrame containing the data to plot. 
The index should represent time (e.g., hours), - and each column represents a separate data series. + data: An xarray Dataset to plot (a pandas DataFrame is also accepted and converted to a Dataset first). For plotting, the data is converted to a DataFrame whose + index represents time and whose columns each represent a separate data series (variables). mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'}) title: The title of the plot. ylabel: The ylabel of the plot. xlabel: The xlabel of the plot. - figsize: Specify the size of the figure - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + figsize: Specify the size of the figure (width, height) in inches. + plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls. + Use this to customize plot properties (e.g., linewidth, alpha, edgecolor). Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. 
@@ -633,45 +726,100 @@ def with_matplotlib( if mode not in ('stacked_bar', 'line'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}") - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + # Create new figure and axes + fig, ax = plt.subplots(figsize=figsize) + + # Initialize plot_kwargs if not provided + if plot_kwargs is None: + plot_kwargs = {} + + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create simple bar/line plot with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + colors_list = [color_discrete_map.get(var, '#808080') for var in variables] + + # Create plot based on mode + if mode == 'stacked_bar': + ax.bar(variables, values, color=colors_list, **plot_kwargs) + elif mode == 'line': + ax.plot( + variables, + values, + marker='o', + color=colors_list[0] if len(set(colors_list)) == 1 else None, + **plot_kwargs, + ) + # If different colors, plot each point separately + if len(set(colors_list)) > 1: + ax.clear() + for i, (var, val) in enumerate(zip(variables, values, strict=False)): + ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs) + ax.set_xticks(range(len(variables))) + ax.set_xticklabels(variables) + + ax.set_xlabel(xlabel, ha='center') + ax.set_ylabel(ylabel, va='center') + ax.set_title(title) + ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y') + fig.tight_layout() + + return fig, ax - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns)) + # 
Resolve colors first (includes validation) + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + + # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form) + df = data.to_dataframe() + + # Get colors in column order + processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns] if mode == 'stacked_bar': - cumulative_positive = np.zeros(len(data)) - cumulative_negative = np.zeros(len(data)) - width = data.index.to_series().diff().dropna().min() # Minimum time difference + cumulative_positive = np.zeros(len(df)) + cumulative_negative = np.zeros(len(df)) + width = df.index.to_series().diff().dropna().min() # Minimum time difference - for i, column in enumerate(data.columns): - positive_values = np.clip(data[column], 0, None) # Keep only positive values - negative_values = np.clip(data[column], None, 0) # Keep only negative values + for i, column in enumerate(df.columns): + positive_values = np.clip(df[column], 0, None) # Keep only positive values + negative_values = np.clip(df[column], None, 0) # Keep only negative values # Plot positive bars ax.bar( - data.index, + df.index, positive_values, bottom=cumulative_positive, color=processed_colors[i], label=column, width=width, align='center', + **plot_kwargs, ) cumulative_positive += positive_values.values # Plot negative bars ax.bar( - data.index, + df.index, negative_values, bottom=cumulative_negative, color=processed_colors[i], label='', # No label for negative bars width=width, align='center', + **plot_kwargs, ) cumulative_negative += negative_values.values elif mode == 'line': - for i, column in enumerate(data.columns): - ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column) + for i, column in enumerate(df.columns): + ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs) # Aesthetics ax.set_xlabel(xlabel, ha='center') @@ -944,228 +1092,88 @@ def plot_network( ) -def 
pie_with_plotly( - data: pd.DataFrame, - colors: ColorType = 'viridis', - title: str = '', - legend_title: str = '', - hole: float = 0.0, - fig: go.Figure | None = None, -) -> go.Figure: - """ - Create a pie chart with Plotly to visualize the proportion of values in a DataFrame. - - Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. - colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) - title: The title of the plot. - legend_title: The title for the legend. - hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). - fig: A Plotly figure object to plot on. If not provided, a new figure will be created. - - Returns: - A Plotly figure object containing the generated pie chart. - - Notes: - - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. - +def preprocess_dataset_for_pie(data: xr.Dataset, lower_percentage_threshold: float = 5.0) -> xr.Dataset: """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - return go.Figure() - - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() - - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. 
Using absolute values for pie chart.') - data_copy = data_copy.abs() - - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] - - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() - - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels) - - # Create figure if not provided - fig = fig if fig is not None else go.Figure() - - # Add pie trace - fig.add_trace( - go.Pie( - labels=labels, - values=values, - hole=hole, - marker=dict(colors=processed_colors), - textinfo='percent+label+value', - textposition='inside', - insidetextorientation='radial', - ) - ) - - # Update layout for better aesthetics - fig.update_layout( - title=title, - legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), # Increase font size for better readability - ) - - return fig + Preprocess data for pie chart display. - -def pie_with_matplotlib( - data: pd.DataFrame, - colors: ColorType = 'viridis', - title: str = '', - legend_title: str = 'Categories', - hole: float = 0.0, - figsize: tuple[int, int] = (10, 8), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, -) -> tuple[plt.Figure, plt.Axes]: - """ - Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame. + Groups items that are individually below the threshold percentage into an "Other" category. + Works with xarray Datasets by summing all dimensions for each variable. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. 
- colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) - title: The title of the plot. - legend_title: The title for the legend. - hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). - figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + data: Input xarray Dataset; each variable becomes one pie slice, summed over all of its dimensions + lower_percentage_threshold: Percentage threshold - items individually below this are grouped into "Other" (only when at least two items fall below it) Returns: - A tuple containing the Matplotlib figure and axes objects used for the plot. - - Notes: - - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. - + Processed xarray Dataset with small items grouped into "Other" """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) - return fig, ax - - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() - - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. 
Using absolute values for pie chart.') - data_copy = data_copy.abs() + dataset = _ensure_dataset(data) + _validate_plotting_data(dataset, allow_empty=True) + + # Sum all dimensions for each variable to get total values + values = {} + for var in data.data_vars: + var_data = data[var] + # Sum across all dimensions to get total + if len(var_data.dims) > 0: + total_value = float(var_data.sum().item()) + else: + total_value = float(var_data.item()) - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + # Handle negative values + if total_value < 0: + print(f'Warning: Negative value for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + # Only keep positive values + if total_value > 0: + values[var] = total_value - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels) + if not values or lower_percentage_threshold <= 0: + return data - # Create figure and axis if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Calculate total and percentages + total = sum(values.values()) + percentages = {name: (val / total) * 100 for name, val in values.items()} - # Draw the pie chart - wedges, texts, autotexts = ax.pie( - values, - labels=labels, - colors=processed_colors, - autopct='%1.1f%%', - startangle=90, - shadow=False, - wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut - ) + # Find items below threshold + below_threshold = {name: val for name, val in values.items() if percentages[name] < lower_percentage_threshold} + above_threshold = {name: val for name, val in values.items() if percentages[name] >= lower_percentage_threshold} - # Adjust the wedgeprops to make donut hole size 
consistent with plotly - # For matplotlib, the hole size is determined by the wedge width - # Convert hole parameter to wedge width - if hole > 0: - # Adjust hole size to match plotly's hole parameter - # In matplotlib, wedge width is relative to the radius (which is 1) - # For plotly, hole is a fraction of the radius - wedge_width = 1 - hole - for wedge in wedges: - wedge.set_width(wedge_width) - - # Customize the appearance - # Make autopct text more visible - for autotext in autotexts: - autotext.set_fontsize(10) - autotext.set_color('white') - - # Set aspect ratio to be equal to ensure a circular pie - ax.set_aspect('equal') - - # Add title - if title: - ax.set_title(title, fontsize=16) + # Only group if there are at least 2 items below threshold + if len(below_threshold) > 1: + # Sum up the small items + other_sum = sum(below_threshold.values()) - # Create a legend if there are many segments - if len(labels) > 6: - ax.legend(wedges, labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0, 0.5, 1)) + # Create new dataset with items above threshold + "Other" + result_dict = above_threshold.copy() + result_dict['Other'] = other_sum - # Apply tight layout - fig.tight_layout() + # Convert back to Dataset + return xr.Dataset({name: xr.DataArray(val) for name, val in result_dict.items()}) - return fig, ax + return data def dual_pie_with_plotly( - data_left: pd.Series, - data_right: pd.Series, + data_left: xr.Dataset | pd.DataFrame, + data_right: xr.Dataset | pd.DataFrame, colors: ColorType = 'viridis', title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', hole: float = 0.2, lower_percentage_group: float = 5.0, - hover_template: str = '%{label}: %{value} (%{percent})', text_info: str = 'percent+label', text_position: str = 'inside', + hover_template: str = '%{label}: %{value} (%{percent})', ) -> go.Figure: """ - Create two pie charts side by side with Plotly, with consistent coloring across both charts. 
+ Create two pie charts side by side with Plotly. Args: - data_left: Series for the left pie chart. - data_right: Series for the right pie chart. - colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + data_left: Dataset for the left pie chart. Variables are summed across all dimensions. + data_right: Dataset for the right pie chart. Variables are summed across all dimensions. + colors: Color specification (colorscale name, list of colors, or dict mapping) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. @@ -1177,119 +1185,74 @@ def dual_pie_with_plotly( text_position: Position of text: 'inside', 'outside', 'auto', or 'none'. Returns: - A Plotly figure object containing the generated dual pie chart. + Plotly Figure object """ - from plotly.subplots import make_subplots - - # Check for empty data - if data_left.empty and data_right.empty: - logger.error('Both datasets are empty. Returning empty figure.') - return go.Figure() - - # Create a subplot figure - fig = make_subplots( - rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05 - ) - - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - - Args: - series: The series to preprocess - - Returns: - A preprocessed pandas Series - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. 
Using absolute values for pie chart.') - series = series.abs() - - # Remove zeros - series = series[series > 0] - - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 - - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group - - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() - - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] - - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum - - return result_series - - return series - - data_left_processed = preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) - - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) + # Preprocess data (converts to Dataset and groups small items) + left_processed = preprocess_dataset_for_pie(data_left, lower_percentage_group) + right_processed = preprocess_dataset_for_pie(data_right, lower_percentage_group) + + # Extract labels and values from Datasets + def extract_from_dataset(ds): + labels = [] + values = [] + for var in ds.data_vars: + var_data = ds[var] + if len(var_data.dims) > 0: + val = float(var_data.sum().item()) + else: + val = float(var_data.item()) + labels.append(str(var)) + values.append(val) + return labels, values - # Get consistent color mapping for both charts using our unified function - color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, 
return_mapping=True) + left_labels, left_values = extract_from_dataset(left_processed) + right_labels, right_values = extract_from_dataset(right_processed) - # Function to create a pie trace with consistently mapped colors - def create_pie_trace(data_series, side): - if data_series.empty: - return None + # Get all unique labels for consistent coloring + all_labels = sorted(set(left_labels) | set(right_labels)) - labels = data_series.index.tolist() - values = data_series.values.tolist() - trace_colors = [color_map[label] for label in labels] + # Create color map + color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) - return go.Pie( - labels=labels, - values=values, - name=side, - marker=dict(colors=trace_colors), - hole=hole, - textinfo=text_info, - textposition=text_position, - insidetextorientation='radial', - hovertemplate=hover_template, - sort=True, # Sort values by default (largest first) + # Create figure + fig = go.Figure() + + # Add left pie + if left_labels: + fig.add_trace( + go.Pie( + labels=left_labels, + values=left_values, + name=subtitles[0], + marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]), + hole=hole, + textinfo=text_info, + textposition=text_position, + hovertemplate=hover_template, + domain=dict(x=[0, 0.48]), + ) ) - # Add left pie if data exists - left_trace = create_pie_trace(data_left_processed, subtitles[0]) - if left_trace: - left_trace.domain = dict(x=[0, 0.48]) - fig.add_trace(left_trace, row=1, col=1) - - # Add right pie if data exists - right_trace = create_pie_trace(data_right_processed, subtitles[1]) - if right_trace: - right_trace.domain = dict(x=[0.52, 1]) - fig.add_trace(right_trace, row=1, col=2) + # Add right pie + if right_labels: + fig.add_trace( + go.Pie( + labels=right_labels, + values=right_values, + name=subtitles[1], + marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]), + hole=hole, + textinfo=text_info, + 
textposition=text_position, + hovertemplate=hover_template, + domain=dict(x=[0.52, 1]), + ) + ) # Update layout fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), margin=dict(t=80, b=50, l=30, r=30), ) @@ -1297,8 +1260,8 @@ def create_pie_trace(data_series, side): def dual_pie_with_matplotlib( - data_left: pd.Series, - data_right: pd.Series, + data_left: xr.Dataset | pd.DataFrame | pd.Series, + data_right: xr.Dataset | pd.DataFrame | pd.Series, colors: ColorType = 'viridis', title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), @@ -1306,154 +1269,109 @@ def dual_pie_with_matplotlib( hole: float = 0.2, lower_percentage_group: float = 5.0, figsize: tuple[int, int] = (14, 7), - fig: plt.Figure | None = None, - axes: list[plt.Axes] | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: """ - Create two pie charts side by side with Matplotlib, with consistent coloring across both charts. - Leverages the existing pie_with_matplotlib function. + Create two pie charts side by side with Matplotlib. Args: data_left: Series for the left pie chart. data_right: Series for the right pie chart. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + colors: Color specification (colormap name, list of colors, or dict mapping) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. hole: Size of the hole in the center for creating donut charts (0.0 to 1.0). lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category. figsize: The size of the figure (width, height) in inches. 
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created. Returns: - A tuple containing the Matplotlib figure and list of axes objects used for the plot. + Tuple of (Figure, list of Axes) """ - # Check for empty data - if data_left.empty and data_right.empty: - logger.error('Both datasets are empty. Returning empty figure.') - if fig is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - return fig, axes - - # Create figure and axes if not provided - if fig is None or axes is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. 
Using absolute values for pie chart.') - series = series.abs() - - # Remove zeros - series = series[series > 0] - - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 - - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group - - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() - - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] + # Preprocess data (converts to Dataset and groups small items) + left_processed = preprocess_dataset_for_pie(data_left, lower_percentage_group) + right_processed = preprocess_dataset_for_pie(data_right, lower_percentage_group) + + # Extract labels and values from Datasets + def extract_from_dataset(ds): + labels = [] + values = [] + for var in ds.data_vars: + var_data = ds[var] + if len(var_data.dims) > 0: + val = float(var_data.sum().item()) + else: + val = float(var_data.item()) + labels.append(str(var)) + values.append(val) + return labels, values - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum + left_labels, left_values = extract_from_dataset(left_processed) + right_labels, right_values = extract_from_dataset(right_processed) - return result_series + # Get all unique labels for consistent coloring + all_labels = sorted(set(left_labels) | set(right_labels)) - return series + # Create color map + color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) - # Preprocess data - data_left_processed = 
preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) + # Create figure + fig, axes = plt.subplots(1, 2, figsize=figsize) - # Convert Series to DataFrames for pie_with_matplotlib - df_left = pd.DataFrame(data_left_processed).T if not data_left_processed.empty else pd.DataFrame() - df_right = pd.DataFrame(data_right_processed).T if not data_right_processed.empty else pd.DataFrame() + def draw_pie(ax, labels, values, subtitle): + """Draw a single pie chart.""" + if not labels: + ax.set_title(subtitle) + ax.axis('off') + return - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) + chart_colors = [color_map[label] for label in labels] - # Get consistent color mapping for both charts using our unified function - color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) + # Draw pie + wedges, texts, autotexts = ax.pie( + values, + labels=labels, + colors=chart_colors, + autopct='%1.1f%%', + startangle=90, + wedgeprops=dict(width=1 - hole) if hole > 0 else None, + ) - # Configure colors for each DataFrame based on the consistent mapping - left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else [] - right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else [] + # Style text + for autotext in autotexts: + autotext.set_fontsize(10) + autotext.set_color('white') + autotext.set_weight('bold') - # Create left pie chart - if not df_left.empty: - pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0]) - else: - axes[0].set_title(subtitles[0]) - axes[0].axis('off') + ax.set_aspect('equal') + ax.set_title(subtitle, fontsize=14, pad=20) - # Create right pie chart - if not df_right.empty: - pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1]) - else: - 
axes[1].set_title(subtitles[1]) - axes[1].axis('off') + # Draw both pies + draw_pie(axes[0], left_labels, left_values, subtitles[0]) + draw_pie(axes[1], right_labels, right_values, subtitles[1]) # Add main title if title: fig.suptitle(title, fontsize=16, y=0.98) - # Adjust layout - fig.tight_layout() - - # Create a unified legend if both charts have data - if not df_left.empty and not df_right.empty: - # Remove individual legends - for ax in axes: - if ax.get_legend(): - ax.get_legend().remove() - - # Create handles for the unified legend - handles = [] - labels_for_legend = [] - - for label in all_labels: - color = color_map[label] - patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label) - handles.append(patch) - labels_for_legend.append(label) + # Create unified legend + if left_labels or right_labels: + handles = [ + plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10) + for label in all_labels + ] - # Add unified legend fig.legend( handles=handles, - labels=labels_for_legend, + labels=all_labels, title=legend_title, loc='lower center', - bbox_to_anchor=(0.5, 0), - ncol=min(len(all_labels), 5), # Limit columns to 5 for readability + bbox_to_anchor=(0.5, -0.02), + ncol=min(len(all_labels), 5), ) - # Add padding at the bottom for the legend - fig.subplots_adjust(bottom=0.2) + fig.subplots_adjust(bottom=0.15) + + fig.tight_layout() return fig, axes @@ -1469,6 +1387,7 @@ def heatmap_with_plotly( | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + **imshow_kwargs: Any, ) -> go.Figure: """ Plot a heatmap visualization using Plotly's imshow with faceting and animation support. @@ -1501,6 +1420,11 @@ def heatmap_with_plotly( - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping (will error if only 1D time data) fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. 
+ **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow. + Common options include: + - aspect: 'auto', 'equal', or a number for aspect ratio + - zmin, zmax: Minimum and maximum values for color scale + - labels: Dict to customize axis labels Returns: A Plotly figure object containing the heatmap visualization. @@ -1589,12 +1513,26 @@ def heatmap_with_plotly( heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] if len(heatmap_dims) < 2: - # Need at least 2 dimensions for a heatmap - logger.error( - f'Heatmap requires at least 2 dimensions for rows and columns. ' - f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' - ) - return go.Figure() + # Handle single-dimension case by adding variable name as a dimension + if len(heatmap_dims) == 1: + # Get the variable name, or use a default + var_name = data.name if data.name else 'value' + + # Expand the DataArray by adding a new dimension with the variable name + data = data.expand_dims({'variable': [var_name]}) + + # Update available dimensions + available_dims = list(data.dims) + heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] + + logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}') + else: + # No dimensions at all - cannot create a heatmap + logger.error( + f'Heatmap requires at least 1 dimension. ' + f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' + ) + return go.Figure() # Setup faceting parameters for Plotly Express # Note: px.imshow only supports facet_col, not facet_row @@ -1631,23 +1569,21 @@ def heatmap_with_plotly( if animate_by: common_args['animation_frame'] = animate_by + # Merge in additional imshow kwargs + common_args.update(imshow_kwargs) + try: fig = px.imshow(**common_args) except Exception as e: logger.error(f'Error creating imshow plot: {e}. 
Falling back to basic heatmap.') # Fallback: create a simple heatmap without faceting - fig = px.imshow( - data.values, - color_continuous_scale=colors if isinstance(colors, str) else 'viridis', - title=title, - ) - - # Update layout with basic styling - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) + fallback_args = { + 'img': data.values, + 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'title': title, + } + fallback_args.update(imshow_kwargs) + fig = px.imshow(**fallback_args) return fig @@ -1657,12 +1593,15 @@ def heatmap_with_matplotlib( colors: ColorType = 'viridis', title: str = '', figsize: tuple[float, float] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + vmin: float | None = None, + vmax: float | None = None, + imshow_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes]: """ Plot a heatmap visualization using Matplotlib's imshow. @@ -1674,16 +1613,25 @@ def heatmap_with_matplotlib( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions. If more than 2 dimensions exist, additional dimensions will be reduced by taking the first slice. - colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu'). + colors: Color specification. Should be a colormap name (e.g., 'turbo', 'RdBu'). title: The title of the heatmap. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. 
reshape_time: Time reshaping configuration: - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + vmin: Minimum value for color scale. If None, uses data minimum. + vmax: Maximum value for color scale. If None, uses data maximum. + imshow_kwargs: Optional dict of parameters to pass to ax.imshow(). + Use this to customize image properties (e.g., interpolation, aspect). + cbar_kwargs: Optional dict of parameters to pass to plt.colorbar(). + Use this to customize colorbar properties (e.g., orientation, label). + **kwargs: Additional keyword arguments passed to ax.imshow(). + Common options include: + - interpolation: 'nearest', 'bilinear', 'bicubic', etc. + - alpha: Transparency level (0-1) + - extent: [left, right, bottom, top] for axis limits Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. 
@@ -1705,19 +1653,33 @@ def heatmap_with_matplotlib( fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) ``` """ + # Initialize kwargs if not provided + if imshow_kwargs is None: + imshow_kwargs = {} + if cbar_kwargs is None: + cbar_kwargs = {} + + # Merge any additional kwargs into imshow_kwargs + # This allows users to pass imshow options directly + imshow_kwargs.update(kwargs) + # Handle empty data if data.size == 0: - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) return fig, ax # Apply time reshaping using the new unified function # Matplotlib doesn't support faceting/animation, so we pass None for those data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill) - # Create figure and axes if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Handle single-dimension case by adding variable name as a dimension + if isinstance(data, xr.DataArray) and len(data.dims) == 1: + var_name = data.name if data.name else 'value' + data = data.expand_dims({'variable': [var_name]}) + logger.debug(f'Only 1 dimension in data. 
Added variable dimension: {var_name}') + + # Create figure and axes + fig, ax = plt.subplots(figsize=figsize) # Extract data values # If data has more than 2 dimensions, we need to reduce it @@ -1745,12 +1707,19 @@ def heatmap_with_matplotlib( # Process colormap cmap = colors if isinstance(colors, str) else 'viridis' - # Create the heatmap using imshow - im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper') + # Create the heatmap using imshow with user customizations + imshow_defaults = {'cmap': cmap, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} + imshow_defaults.update(imshow_kwargs) # User kwargs override defaults + im = ax.imshow(values, **imshow_defaults) + + # Add colorbar with user customizations + cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05} + cbar_defaults.update(cbar_kwargs) # User kwargs override defaults + cbar = plt.colorbar(im, **cbar_defaults) - # Add colorbar - cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05) - cbar.set_label('Value') + # Set colorbar label if not overridden by user + if 'label' not in cbar_kwargs: + cbar.set_label('Value') # Set labels and title ax.set_xlabel(str(x_labels).capitalize()) @@ -1770,6 +1739,7 @@ def export_figure( user_path: pathlib.Path | None = None, show: bool = True, save: bool = False, + dpi: int = 300, ) -> go.Figure | tuple[plt.Figure, plt.Axes]: """ Export a figure to a file and or show it. @@ -1779,8 +1749,9 @@ def export_figure( default_path: The default file path if no user filename is provided. default_filetype: The default filetype if the path doesnt end with a filetype. user_path: An optional user-specified file path. - show: Whether to display the figure (default: True). + show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None). save: Whether to save the figure (default: False). + dpi: DPI (dots per inch) for saving Matplotlib figures. 
If None, uses CONFIG.Plotting.default_dpi. Raises: ValueError: If no default filetype is provided and the path doesn't specify a filetype. @@ -1838,7 +1809,7 @@ def export_figure( plt.show() if save: - fig.savefig(str(filename), dpi=300) + fig.savefig(str(filename), dpi=dpi) plt.close(fig) # Close figure to free memory return fig, ax diff --git a/flixopt/results.py b/flixopt/results.py index 75f8f300e..cadb14236 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -107,6 +107,20 @@ class CalculationResults: ).mean() ``` + Configure automatic color management for plots: + + ```python + # Dict-based configuration: + results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues', 'Battery': 'green'}) + + # All plots automatically use configured colors (colors=None is the default) + results['ElectricityBus'].plot_node_balance() + results['Battery'].plot_charge_state() + + # Override when needed + results['ElectricityBus'].plot_node_balance(colors='turbo') # Ignores setup + ``` + Design Patterns: **Factory Methods**: Use `from_file()` and `from_calculation()` for creation or access directly from `Calculation.results` **Dictionary Access**: Use `results[element_label]` for element-specific results @@ -721,6 +735,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots a heatmap visualization of a variable using imshow or time-based reshaping. @@ -754,6 +769,20 @@ def plot_heatmap( Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min' fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). Default is 'ffill'. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. 
+ + For heatmaps specifically: + + - **vmin** (float): Minimum value for color scale (both engines). + - **vmax** (float): Maximum value for color scale (both engines). + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow (e.g., interpolation, aspect). + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Examples: Direct imshow mode (default): @@ -794,6 +823,18 @@ def plot_heatmap( ... animate_by='period', ... reshape_time=('D', 'h'), ... ) + + High-resolution export with custom color range: + + >>> results.plot_heatmap('Battery|charge_state', save=True, dpi=600, vmin=0, vmax=100) + + Matplotlib heatmap with custom imshow settings: + + >>> results.plot_heatmap( + ... 'Boiler(Q_th)|flow_rate', + ... engine='matplotlib', + ... imshow_kwargs={'interpolation': 'bilinear', 'aspect': 'auto'}, + ... ) """ # Delegate to module-level plot_heatmap function return plot_heatmap( @@ -814,6 +855,7 @@ def plot_heatmap( heatmap_timeframes=heatmap_timeframes, heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, color_map=color_map, + **plot_kwargs, ) def plot_network( @@ -994,6 +1036,7 @@ def plot_node_balance( facet_cols: int = 3, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots the node balance of the Component or Bus with optional faceting and animation. @@ -1021,6 +1064,27 @@ def plot_node_balance( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. 
+ + **For Plotly engine** (`engine='plotly'`): + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + Example: `trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}` + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + Example: `layout_kwargs={'width': 1200, 'height': 600, 'template': 'plotly_dark'}` + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine** (`engine='matplotlib'`): + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + Example: `plot_kwargs={'linewidth': 3, 'alpha': 0.7, 'edgecolor': 'black'}` + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Examples: Basic plot (current behavior): @@ -1052,6 +1116,24 @@ def plot_node_balance( Time range selection (summer months only): >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario') + + High-resolution export for publication: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', save='figure.png', dpi=600) + + Custom Plotly theme and layout: + + >>> results['Boiler'].plot_node_balance( + ... layout_kwargs={'template': 'plotly_dark', 'width': 1200, 'height': 600} + ... ) + + Custom line styling: + + >>> results['Boiler'].plot_node_balance(mode='line', trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}) + + Custom matplotlib appearance: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', plot_kwargs={'linewidth': 3, 'alpha': 0.7}) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1073,8 +1155,11 @@ def plot_node_balance( if engine not in {'plotly', 'matplotlib'}: raise ValueError(f'Engine "{engine}" not supported. 
Use one of ["plotly", "matplotlib"]') + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Don't pass select/indexer to node_balance - we'll apply it afterwards - ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix) + ds = self.node_balance(with_last_timestep=False, unit_type=unit_type, drop_suffix=drop_suffix) ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) @@ -1101,14 +1186,17 @@ def plot_node_balance( mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) default_filetype = '.html' else: figure_like = plotting.with_matplotlib( - ds.to_dataframe(), + ds, colors=colors, mode=mode, title=title, + **plot_kwargs, ) default_filetype = '.png' @@ -1119,6 +1207,7 @@ def plot_node_balance( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def plot_node_balance_pie( @@ -1132,6 +1221,7 @@ def plot_node_balance_pie( select: dict[FlowSystemDimensions, Any] | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]: """Plot pie chart of flow hours distribution. @@ -1151,6 +1241,17 @@ def plot_node_balance_pie( engine: Plotting engine ('plotly' or 'matplotlib'). select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Use this to select specific scenario/period before creating the pie chart. + **plot_kwargs: Additional plotting customization options. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + - **hover_template** (str): Hover text template (Plotly only). + Example: `hover_template='%{label}: %{value} (%{percent})'` + - **text_position** (str): Text position ('inside', 'outside', 'auto'). 
+ - **hole** (float): Size of donut hole (0.0 to 1.0). + + See :func:`flixopt.plotting.dual_pie_with_plotly` for complete reference. Examples: Basic usage (auto-selects first scenario/period if present): @@ -1160,6 +1261,14 @@ def plot_node_balance_pie( Explicitly select a scenario and period: >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030}) + + Create a donut chart with custom hover text: + + >>> results['Bus'].plot_node_balance_pie(hole=0.4, hover_template='%{label}: %{value:.2f} (%{percent})') + + High-resolution export: + + >>> results['Bus'].plot_node_balance_pie(save='figure.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1178,6 +1287,9 @@ def plot_node_balance_pie( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + inputs = sanitize_dataset( ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep, threshold=1e-5, @@ -1235,14 +1347,15 @@ def plot_node_balance_pie( if engine == 'plotly': figure_like = plotting.dual_pie_with_plotly( - data_left=inputs.to_pandas(), - data_right=outputs.to_pandas(), + data_left=inputs, + data_right=outputs, colors=colors, title=title, text_info=text_info, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -1255,6 +1368,7 @@ def plot_node_balance_pie( subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.png' else: @@ -1267,6 +1381,7 @@ def plot_node_balance_pie( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance( @@ -1373,6 +1488,7 @@ def plot_charge_state( facet_cols: int = 3, # Deprecated parameter (kept for backwards compatibility) indexer: 
dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure: """Plot storage charge state over time, combined with the node balance with optional faceting and animation. @@ -1389,6 +1505,24 @@ def plot_charge_state( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine:** + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine:** + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Raises: ValueError: If component is not a storage. @@ -1409,6 +1543,14 @@ def plot_charge_state( Facet by scenario AND animate by period: >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period') + + Custom layout: + + >>> results['Storage'].plot_charge_state(layout_kwargs={'template': 'plotly_dark', 'height': 800}) + + High-resolution export: + + >>> results['Storage'].plot_charge_state(save='storage.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1427,11 +1569,14 @@ def plot_charge_state( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + if not self.is_storage: raise ValueError(f'Cant plot charge_state. 
"{self.label}" is not a storage') # Get node balance and charge state - ds = self.node_balance(with_last_timestep=True) + ds = self.node_balance(with_last_timestep=True).fillna(0) charge_state_da = self.charge_state # Apply select filtering @@ -1451,11 +1596,12 @@ def plot_charge_state( mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) - # Create a dataset with just charge_state and plot it as lines - # This ensures proper handling of facets and animation - charge_state_ds = charge_state_da.to_dataset(name=self._charge_state) + # Prepare charge_state as Dataset for plotting + charge_state_ds = xr.Dataset({self._charge_state: charge_state_da}) # Plot charge_state with mode='line' to get Scatter traces charge_state_fig = plotting.with_plotly( @@ -1466,6 +1612,8 @@ def plot_charge_state( mode='line', # Always line for charge_state title='', # No title needed for this temp figure facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) # Add charge_state traces to the main figure @@ -1473,6 +1621,7 @@ def plot_charge_state( for trace in charge_state_fig.data: trace.line.width = 2 # Make charge_state line more prominent trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows) + trace.line.color = 'black' figure_like.add_trace(trace) # Also add traces from animation frames if they exist @@ -1497,10 +1646,11 @@ def plot_charge_state( ) # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( - ds.to_dataframe(), + ds, colors=colors, mode=mode, title=title, + **plot_kwargs, ) # Add charge_state as a line overlay @@ -1525,6 +1675,7 @@ def plot_charge_state( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance_with_charge_state( @@ -1810,6 +1961,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, 
heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """Plot heatmap of variable solution across segments. @@ -1830,6 +1982,17 @@ def plot_heatmap( heatmap_timeframes: (Deprecated) Use reshape_time instead. heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead. color_map: (Deprecated) Use colors instead. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + - **vmin** (float): Minimum value for color scale. + - **vmax** (float): Maximum value for color scale. + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow. + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Returns: Figure object. @@ -1884,6 +2047,7 @@ def plot_heatmap( animate_by=animate_by, facet_cols=facet_cols, fill=fill, + **plot_kwargs, ) def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5): @@ -1933,6 +2097,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ): """Plot heatmap visualization with support for multi-variable, faceting, and animation. 
@@ -2087,6 +2252,9 @@ def plot_heatmap( timeframes, timesteps_per_frame = reshape_time title += f' ({timeframes} vs {timesteps_per_frame})' + # Extract dpi before passing to plotting functions + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Plot with appropriate engine if engine == 'plotly': figure_like = plotting.heatmap_with_plotly( @@ -2098,6 +2266,7 @@ def plot_heatmap( facet_cols=facet_cols, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -2107,6 +2276,7 @@ def plot_heatmap( title=title, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.png' else: @@ -2123,6 +2293,7 @@ def plot_heatmap( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py new file mode 100644 index 000000000..f59601dca --- /dev/null +++ b/tests/test_plotting_api.py @@ -0,0 +1,64 @@ +"""Smoke tests for plotting API robustness improvements.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flixopt import plotting + + +@pytest.fixture +def sample_dataset(): + """Create a sample xarray Dataset for testing.""" + time = np.arange(10) + data = xr.Dataset( + { + 'var1': (['time'], np.random.rand(10)), + 'var2': (['time'], np.random.rand(10)), + 'var3': (['time'], np.random.rand(10)), + }, + coords={'time': time}, + ) + return data + + +@pytest.fixture +def sample_dataframe(): + """Create a sample pandas DataFrame for testing.""" + time = np.arange(10) + df = pd.DataFrame({'var1': np.random.rand(10), 'var2': np.random.rand(10), 'var3': np.random.rand(10)}, index=time) + df.index.name = 'time' + return df + + +def test_kwargs_passthrough_plotly(sample_dataset): + """Test that backend-specific kwargs are passed through correctly.""" + fig = plotting.with_plotly( + sample_dataset, + mode='line', + 
trace_kwargs={'line': {'width': 5}}, + layout_kwargs={'width': 1200, 'height': 600}, + ) + assert fig.layout.width == 1200 + assert fig.layout.height == 600 + + +def test_dataframe_support_plotly(sample_dataframe): + """Test that DataFrames are accepted by plotting functions.""" + fig = plotting.with_plotly(sample_dataframe, mode='line') + assert fig is not None + + +def test_data_validation_non_numeric(): + """Test that validation catches non-numeric data.""" + data = xr.Dataset({'var1': (['time'], ['a', 'b', 'c'])}, coords={'time': [0, 1, 2]}) + + with pytest.raises(TypeError, match='non-numeric dtype'): + plotting.with_plotly(data) + + +def test_ensure_dataset_invalid_type(): + """Test that _ensure_dataset raises error for invalid types.""" + with pytest.raises(TypeError, match='must be xr.Dataset or pd.DataFrame'): + plotting._ensure_dataset([1, 2, 3])