From 85911e8e3a93fc0e044563fc0c7ed9d2d355e5f8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:29:39 +0100 Subject: [PATCH 01/30] Improve documentation and improve CHANGELOG.md --- CHANGELOG.md | 70 ++++++++++++++++++++++++++++- docs/user-guide/results-plotting.md | 3 ++ mkdocs.yml | 7 ++- 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f60497d80..191a8c28c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,7 @@ Until here --> ## [5.1.0] - Upcoming -**Summary**: Time-series clustering for faster optimization with configurable storage behavior across typical periods. Improved weights API with always-normalized scenario weights. +**Summary**: Major feature release introducing time-series clustering for faster optimization and the new `fxplot` accessor for universal xarray plotting. Includes configurable storage behavior across typical periods and improved weights API. ### ✨ Added @@ -121,6 +121,44 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] Use `'cyclic'` for short-term storage like batteries or hot water tanks where only daily patterns matter. Use `'independent'` for quick estimates when storage behavior isn't critical. +**FXPlot Accessor**: New global xarray accessors for universal plotting with automatic faceting and smart dimension handling. Works on any xarray Dataset, not just flixopt results. 
+ +```python +import flixopt as fx # Registers accessors automatically + +# Plot any xarray Dataset with automatic faceting +dataset.fxplot.bar(x='component') +dataset.fxplot.area(x='time') +dataset.fxplot.heatmap(x='time', y='component') +dataset.fxplot.line(x='time', facet_col='scenario') + +# DataArray support +data_array.fxplot.line() + +# Statistics transformations +dataset.fxstats.to_duration_curve() +``` + +**Available Plot Methods**: + +| Method | Description | +|--------|-------------| +| `.fxplot.bar()` | Grouped bar charts | +| `.fxplot.stacked_bar()` | Stacked bar charts | +| `.fxplot.line()` | Line charts with faceting | +| `.fxplot.area()` | Stacked area charts | +| `.fxplot.heatmap()` | Heatmap visualizations | +| `.fxplot.scatter()` | Scatter plots | +| `.fxplot.pie()` | Pie charts with faceting | +| `.fxstats.to_duration_curve()` | Transform to duration curve format | + +**Key Features**: + +- **Auto-faceting**: Automatically assigns extra dimensions (period, scenario, cluster) to `facet_col`, `facet_row`, or `animation_frame` +- **Smart x-axis**: Intelligently selects x dimension based on priority (time > duration > period > scenario) +- **Universal**: Works on any xarray Dataset/DataArray, not limited to flixopt +- **Configurable**: Customize via `CONFIG.Plotting` (colorscales, facet columns, line shapes) + ### 💥 Breaking Changes - `FlowSystem.scenario_weights` are now always normalized to sum to 1 when set (including after `.sel()` subsetting) @@ -134,10 +172,35 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] - `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` +**Topology method name simplifications** (old names still work with deprecation warnings, removal in v6.0.0): + +| Old (v5.0.0) | New (v5.1.0) | +|--------------|--------------| +| `topology.plot_network()` | `topology.plot()` | +| `topology.start_network_app()` | `topology.start_app()` | +| `topology.stop_network_app()` | 
`topology.stop_app()` | +| `topology.network_infos()` | `topology.infos()` | + +Note: `topology.plot()` now renders a Sankey diagram. The old PyVis visualization is available via `topology.plot_legacy()`. + ### 🐛 Fixed - `temporal_weight` and `sum_temporal()` now use consistent implementation +### 📝 Docs + +**New Documentation Pages:** + +- [Time-Series Clustering Guide](https://flixopt.github.io/flixopt/latest/user-guide/optimization/clustering/) - Comprehensive guide to clustering workflows + +**New Jupyter Notebooks:** + +- **08c-clustering.ipynb** - Introduction to time-series clustering +- **08c2-clustering-storage-modes.ipynb** - Comparison of all 4 storage cluster modes +- **08d-clustering-multiperiod.ipynb** - Clustering with periods and scenarios +- **08e-clustering-internals.ipynb** - Understanding clustering internals +- **fxplot_accessor_demo.ipynb** - Demo of the new fxplot accessor + ### 👷 Development **New Test Suites for Clustering**: @@ -147,6 +210,11 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] - `TestMultiPeriodClustering`: Tests for clustering with periods and scenarios dimensions - `TestPeakSelection`: Tests for `time_series_for_high_peaks` and `time_series_for_low_peaks` parameters +**New Test Suites for Other Features**: + +- `test_clustering_io.py` - Tests for clustering serialization roundtrip +- `test_sel_isel_single_selection.py` - Tests for transform selection methods + --- diff --git a/docs/user-guide/results-plotting.md b/docs/user-guide/results-plotting.md index 1ecd26aa1..28e3d2b2b 100644 --- a/docs/user-guide/results-plotting.md +++ b/docs/user-guide/results-plotting.md @@ -2,6 +2,9 @@ After solving an optimization, flixOpt provides a powerful plotting API to visualize and analyze your results. The API is designed to be intuitive and chainable, giving you quick access to common plots while still allowing deep customization. +!!! 
tip "Plotting Custom Data" + For plotting arbitrary xarray data (not just flixopt results), see the [Custom Data Plotting](recipes/plotting-custom-data.md) guide which covers the `.fxplot` accessor. + ## The Plot Accessor All plotting is accessed through the `statistics.plot` accessor on your FlowSystem: diff --git a/mkdocs.yml b/mkdocs.yml index ab2e9309f..6ac519130 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,8 +69,13 @@ nav: - Piecewise Effects: notebooks/06c-piecewise-effects.ipynb - Scaling: - Scenarios: notebooks/07-scenarios-and-periods.ipynb - - Clustering: notebooks/08a-aggregation.ipynb + - Resampling: notebooks/08a-aggregation.ipynb - Rolling Horizon: notebooks/08b-rolling-horizon.ipynb + - Clustering: + - Basics: notebooks/08c-clustering.ipynb + - Storage Modes: notebooks/08c2-clustering-storage-modes.ipynb + - Multi-Period: notebooks/08d-clustering-multiperiod.ipynb + - Internals: notebooks/08e-clustering-internals.ipynb - Results: - Plotting: notebooks/09-plotting-and-data-access.ipynb - Custom Data Plotting: notebooks/fxplot_accessor_demo.ipynb From e4cd2701b5b69c09749cdc1d150faa6e6265fdd9 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:48:26 +0100 Subject: [PATCH 02/30] FIx CHangelog and change to v6.0.0 --- CHANGELOG.md | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 191a8c28c..9fe871469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,9 +51,12 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp Until here --> -## [5.1.0] - Upcoming +## [6.0.0] - Upcoming -**Summary**: Major feature release introducing time-series clustering for faster optimization and the new `fxplot` accessor for universal xarray plotting. Includes configurable storage behavior across typical periods and improved weights API. 
+**Summary**: Major release introducing time-series clustering with storage inter-cluster linking, the new `fxplot` accessor for universal xarray plotting, and removal of deprecated v5.0 classes. Includes configurable storage behavior across typical periods and improved weights API. + +!!! warning "Breaking Changes" + This release removes `ClusteredOptimization` and `ClusteringParameters` which were deprecated in v5.0.0. Use `flow_system.transform.cluster()` instead. See [Migration](#migration-from-clusteredoptimization) below. ### ✨ Added @@ -183,6 +186,42 @@ dataset.fxstats.to_duration_curve() Note: `topology.plot()` now renders a Sankey diagram. The old PyVis visualization is available via `topology.plot_legacy()`. +### 🔥 Removed + +**Clustering classes removed** (deprecated in v5.0.0): + +- `ClusteredOptimization` class - Use `flow_system.transform.cluster()` then `optimize()` +- `ClusteringParameters` class - Parameters are now passed directly to `transform.cluster()` +- `flixopt/clustering.py` module - Restructured to `flixopt/clustering/` package with new classes + +#### Migration from ClusteredOptimization + +=== "v5.x (Old - No longer works)" + ```python + from flixopt import ClusteredOptimization, ClusteringParameters + + params = ClusteringParameters(hours_per_period=24, nr_of_periods=8) + calc = ClusteredOptimization('model', flow_system, params) + calc.do_modeling_and_solve(solver) + results = calc.results + ``` + +=== "v6.0.0 (New)" + ```python + # Cluster using transform accessor + fs_clustered = flow_system.transform.cluster( + n_clusters=8, # was: nr_of_periods + cluster_duration='1D', # was: hours_per_period=24 + ) + fs_clustered.optimize(solver) + + # Results on the clustered FlowSystem + costs = fs_clustered.solution['costs'].item() + + # Expand back to full resolution if needed + fs_expanded = fs_clustered.transform.expand_solution() + ``` + ### 🐛 Fixed - `temporal_weight` and `sum_temporal()` now use consistent implementation From 
c20f94f5cd9e83d384a6c83f0d2fef6b4ad48a95 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:50:17 +0100 Subject: [PATCH 03/30] FIx CHangelog and change to v6.0.0 --- CHANGELOG.md | 23 ++++++++++++++++++++--- flixopt/config.py | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fe871469..1d99886c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -173,12 +173,29 @@ dataset.fxstats.to_duration_curve() ### 🗑️ Deprecated +The following items are deprecated and will be removed in **v7.0.0**: + +**Classes** (use FlowSystem methods instead): + +- `Optimization` class → Use `flow_system.optimize(solver)` +- `SegmentedOptimization` class → Use `flow_system.optimize.rolling_horizon()` +- `Results` class → Use `flow_system.solution` and `flow_system.statistics` +- `SegmentedResults` class → Use segment FlowSystems directly + +**FlowSystem methods** (use `transform` accessor instead): + +- `flow_system.sel()` → Use `flow_system.transform.sel()` +- `flow_system.isel()` → Use `flow_system.transform.isel()` +- `flow_system.resample()` → Use `flow_system.transform.resample()` + +**Parameters:** + - `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` -**Topology method name simplifications** (old names still work with deprecation warnings, removal in v6.0.0): +**Topology method name simplifications** (old names still work with deprecation warnings, removal in v7.0.0): -| Old (v5.0.0) | New (v5.1.0) | -|--------------|--------------| +| Old (v5.x) | New (v6.0.0) | +|------------|--------------| | `topology.plot_network()` | `topology.plot()` | | `topology.start_network_app()` | `topology.start_app()` | | `topology.stop_network_app()` | `topology.stop_app()` | diff --git a/flixopt/config.py b/flixopt/config.py index 454f8ad3e..602652252 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -30,7 +30,7 @@ logging.addLevelName(SUCCESS_LEVEL, 
'SUCCESS') # Deprecation removal version - update this when planning the next major version -DEPRECATION_REMOVAL_VERSION = '6.0.0' +DEPRECATION_REMOVAL_VERSION = '7.0.0' class MultilineFormatter(logging.Formatter): From c6c9a75ce506e270560e15bbb1d546e235c561dc Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:50:37 +0100 Subject: [PATCH 04/30] FIx CHangelog and change to v6.0.0 --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d99886c8..a09539d81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,11 +182,15 @@ The following items are deprecated and will be removed in **v7.0.0**: - `Results` class → Use `flow_system.solution` and `flow_system.statistics` - `SegmentedResults` class → Use segment FlowSystems directly -**FlowSystem methods** (use `transform` accessor instead): +**FlowSystem methods** (use `transform` or `topology` accessor instead): - `flow_system.sel()` → Use `flow_system.transform.sel()` - `flow_system.isel()` → Use `flow_system.transform.isel()` - `flow_system.resample()` → Use `flow_system.transform.resample()` +- `flow_system.plot_network()` → Use `flow_system.topology.plot()` +- `flow_system.start_network_app()` → Use `flow_system.topology.start_app()` +- `flow_system.stop_network_app()` → Use `flow_system.topology.stop_app()` +- `flow_system.network_infos()` → Use `flow_system.topology.infos()` **Parameters:** From 3d8e6008ed25bcfc1c521fdf6fa734f56c7ce240 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 11:49:26 +0100 Subject: [PATCH 05/30] Enhanced Clustering Control New Parameters Added to cluster() Method | Parameter | Type | Default | Purpose | |-------------------------|-------------------------------|----------------------|--------------------------------------------------------------------------------------------------------------------| | 
cluster_method | Literal[...] | 'k_means' | Clustering algorithm ('k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging') | | representation_method | Literal[...] | 'meanRepresentation' | How clusters are represented ('meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation') | | extreme_period_method | Literal[...] | 'new_cluster_center' | How peaks are integrated ('None', 'append', 'new_cluster_center', 'replace_cluster_center') | | rescale_cluster_periods | bool | True | Rescale clusters to match original means | | random_state | int | None | None | Random seed for reproducibility | | predef_cluster_order | np.ndarray | list[int] | None | None | Manual clustering assignments | | **tsam_kwargs | Any | - | Pass-through for any tsam parameter | Clustering Quality Metrics Access via fs.clustering.metrics after clustering - returns a DataFrame with RMSE, MAE, and other accuracy indicators per time series. Files Modified 1. flixopt/transform_accessor.py - Updated cluster() signature and tsam call 2. flixopt/clustering/base.py - Added metrics field to Clustering class 3. tests/test_clustering/test_integration.py - Added tests for new parameters 4. 
docs/user-guide/optimization/clustering.md - Updated documentation --- docs/user-guide/optimization/clustering.md | 56 ++++++++++++++++++ flixopt/clustering/base.py | 2 + flixopt/transform_accessor.py | 46 +++++++++++++-- tests/test_clustering/test_integration.py | 68 ++++++++++++++++++++++ 4 files changed, 168 insertions(+), 4 deletions(-) diff --git a/docs/user-guide/optimization/clustering.md b/docs/user-guide/optimization/clustering.md index 7ec5faac1..aca1b3eaf 100644 --- a/docs/user-guide/optimization/clustering.md +++ b/docs/user-guide/optimization/clustering.md @@ -52,6 +52,10 @@ flow_rates = fs_expanded.solution['Boiler(Q_th)|flow_rate'] | `cluster_duration` | Duration of each cluster | `'1D'`, `'24h'`, or `24` (hours) | | `time_series_for_high_peaks` | Time series where peak clusters must be captured | `['HeatDemand(Q)|fixed_relative_profile']` | | `time_series_for_low_peaks` | Time series where minimum clusters must be captured | `['SolarGen(P)|fixed_relative_profile']` | +| `cluster_method` | Clustering algorithm | `'k_means'`, `'hierarchical'`, `'k_medoids'` | +| `representation_method` | How clusters are represented | `'meanRepresentation'`, `'medoidRepresentation'` | +| `random_state` | Random seed for reproducibility | `42` | +| `rescale_cluster_periods` | Rescale clusters to match original means | `True` (default) | ### Peak Selection @@ -68,6 +72,58 @@ fs_clustered = flow_system.transform.cluster( Without peak selection, the clustering algorithm might average out extreme days, leading to undersized equipment. 
+### Advanced Clustering Options + +Fine-tune the clustering algorithm with advanced parameters: + +```python +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + cluster_method='hierarchical', # Alternative to k_means + representation_method='medoidRepresentation', # Use actual periods, not averages + rescale_cluster_periods=True, # Match original time series means + random_state=42, # Reproducible results +) +``` + +**Available clustering algorithms** (`cluster_method`): + +| Method | Description | +|--------|-------------| +| `'k_means'` | Fast, good for most cases (default) | +| `'hierarchical'` | Produces consistent hierarchical groupings | +| `'k_medoids'` | Uses actual periods as representatives | +| `'k_maxoids'` | Maximizes representativeness | +| `'averaging'` | Simple averaging of similar periods | + +For advanced tsam parameters not exposed directly, use `**kwargs`: + +```python +# Pass any tsam.TimeSeriesAggregation parameter +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + sameMean=True, # Normalize all time series to same mean + sortValues=True, # Cluster by duration curves instead of shape +) +``` + +### Clustering Quality Metrics + +After clustering, access quality metrics to evaluate the aggregation accuracy: + +```python +fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') + +# Access clustering metrics +metrics = fs_clustered.clustering.metrics +print(metrics) + +# Metrics include RMSE, MAE per time series +# Use these to assess if more clusters are needed +``` + ## Storage Modes Storage behavior during clustering is controlled via the `cluster_mode` parameter: diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 4b31832e4..34aba2ded 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -993,6 +993,7 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. 
backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). + metrics: Clustering quality metrics (RMSE, MAE, etc.) per time series. Example: >>> fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') @@ -1004,6 +1005,7 @@ class Clustering: result: ClusterResult backend_name: str = 'unknown' + metrics: pd.DataFrame | None = None def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create reference structure for serialization.""" diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 3a13dbb63..010499f46 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -582,6 +582,17 @@ def cluster( weights: dict[str, float] | None = None, time_series_for_high_peaks: list[str] | None = None, time_series_for_low_peaks: list[str] | None = None, + cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'k_means', + representation_method: Literal[ + 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' + ] = 'meanRepresentation', + extreme_period_method: Literal[ + 'None', 'append', 'new_cluster_center', 'replace_cluster_center' + ] = 'new_cluster_center', + rescale_cluster_periods: bool = True, + random_state: int | None = None, + predef_cluster_order: np.ndarray | list[int] | None = None, + **tsam_kwargs: Any, ) -> FlowSystem: """ Create a FlowSystem with reduced timesteps using typical clusters. @@ -607,6 +618,24 @@ def cluster( time_series_for_high_peaks: Time series labels for explicitly selecting high-value clusters. **Recommended** for demand time series to capture peak demand days. time_series_for_low_peaks: Time series labels for explicitly selecting low-value clusters. + cluster_method: Clustering algorithm to use. Options: + ``'k_means'`` (default), ``'k_medoids'``, ``'hierarchical'``, + ``'k_maxoids'``, ``'averaging'``. 
+ representation_method: How cluster representatives are computed. Options: + ``'meanRepresentation'`` (default), ``'medoidRepresentation'``, + ``'distributionAndMinMaxRepresentation'``. + extreme_period_method: How extreme periods (peaks) are integrated. Options: + ``'new_cluster_center'`` (default), ``'None'``, ``'append'``, + ``'replace_cluster_center'``. + rescale_cluster_periods: If True (default), rescale cluster periods so their + weighted mean matches the original time series mean. + random_state: Random seed for reproducible clustering results. If None, + results may vary between runs. + predef_cluster_order: Predefined cluster assignments for manual clustering. + Array of cluster indices (0 to n_clusters-1) for each original period. + If provided, clustering is skipped and these assignments are used directly. + **tsam_kwargs: Additional keyword arguments passed to + ``tsam.TimeSeriesAggregation``. See tsam documentation for all options. Returns: A new FlowSystem with reduced timesteps (only typical clusters). 
@@ -680,7 +709,10 @@ def cluster( tsam_results: dict[tuple, tsam.TimeSeriesAggregation] = {} cluster_orders: dict[tuple, np.ndarray] = {} cluster_occurrences_all: dict[tuple, dict] = {} - use_extreme_periods = bool(time_series_for_high_peaks or time_series_for_low_peaks) + + # Set random seed for reproducibility + if random_state is not None: + np.random.seed(random_state) for period_label in periods: for scenario_label in scenarios: @@ -700,11 +732,15 @@ def cluster( noTypicalPeriods=n_clusters, hoursPerPeriod=hours_per_cluster, resolution=dt, - clusterMethod='k_means', - extremePeriodMethod='new_cluster_center' if use_extreme_periods else 'None', + clusterMethod=cluster_method, + extremePeriodMethod=extreme_period_method, + representationMethod=representation_method, + rescaleClusterPeriods=rescale_cluster_periods, + predefClusterOrder=predef_cluster_order, weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], + **tsam_kwargs, ) # Suppress tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): @@ -715,9 +751,10 @@ def cluster( cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur - # Use first result for structure + # Use first result for structure and metrics first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] + clustering_metrics = first_tsam.accuracyIndicators() n_reduced_timesteps = len(first_tsam.typicalPeriods) actual_n_clusters = len(first_tsam.clusterPeriodNoOccur) @@ -932,6 +969,7 @@ def _build_cluster_weights_for_key(key: tuple) -> xr.DataArray: reduced_fs.clustering = Clustering( result=aggregation_result, backend_name='tsam', + metrics=clustering_metrics, ) return reduced_fs diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 2d04a51c1..2bcd0b022 100644 
--- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -170,6 +170,74 @@ def test_cluster_reduces_timesteps(self): assert len(fs_clustered.timesteps) * len(fs_clustered.clusters) == 48 +class TestClusterAdvancedOptions: + """Tests for advanced clustering options.""" + + @pytest.fixture + def basic_flow_system(self): + """Create a basic FlowSystem for testing.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h')) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + return fs + + def test_cluster_method_parameter(self, basic_flow_system): + """Test that cluster_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', cluster_method='hierarchical' + ) + assert len(fs_clustered.clusters) == 2 + + def test_random_state_reproducibility(self, basic_flow_system): + """Test that random_state produces reproducible results.""" + fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + + # Same random state should produce identical cluster orders + xr.testing.assert_equal(fs1.clustering.cluster_order, fs2.clustering.cluster_order) + + def test_metrics_available(self, basic_flow_system): + """Test that clustering metrics are available after clustering.""" + fs_clustered = 
basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + + assert fs_clustered.clustering.metrics is not None + assert isinstance(fs_clustered.clustering.metrics, pd.DataFrame) + assert len(fs_clustered.clustering.metrics) > 0 + + def test_representation_method_parameter(self, basic_flow_system): + """Test that representation_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', representation_method='medoidRepresentation' + ) + assert len(fs_clustered.clusters) == 2 + + def test_rescale_cluster_periods_parameter(self, basic_flow_system): + """Test that rescale_cluster_periods parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', rescale_cluster_periods=False + ) + assert len(fs_clustered.clusters) == 2 + + def test_tsam_kwargs_passthrough(self, basic_flow_system): + """Test that additional kwargs are passed to tsam.""" + # sameMean is a valid tsam parameter + fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', sameMean=True) + assert len(fs_clustered.clusters) == 2 + + class TestClusteringModuleImports: """Tests for flixopt.clustering module imports.""" From 0abdb002036c896d0b4d85b5d72702d5869c7da7 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 12:17:57 +0100 Subject: [PATCH 06/30] =?UTF-8?q?=20=20Dimension=20renamed:=20original=5Fp?= =?UTF-8?q?eriod=20=E2=86=92=20original=5Fcluster=20=20=20Property=20renam?= =?UTF-8?q?ed:=20n=5Foriginal=5Fperiods=20=E2=86=92=20n=5Foriginal=5Fclust?= =?UTF-8?q?ers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/user-guide/optimization/clustering.md | 8 +-- flixopt/clustering/base.py | 54 ++++++++------- flixopt/clustering/intercluster_helpers.py | 6 +- flixopt/components.py | 45 ++++++------ flixopt/transform_accessor.py | 81 
+++++++++++++++++----- tests/test_cluster_reduce_expand.py | 6 +- tests/test_clustering/test_base.py | 10 +-- tests/test_clustering/test_integration.py | 34 ++++++++- 8 files changed, 164 insertions(+), 80 deletions(-) diff --git a/docs/user-guide/optimization/clustering.md b/docs/user-guide/optimization/clustering.md index aca1b3eaf..793fbf8fe 100644 --- a/docs/user-guide/optimization/clustering.md +++ b/docs/user-guide/optimization/clustering.md @@ -116,12 +116,12 @@ After clustering, access quality metrics to evaluate the aggregation accuracy: ```python fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') -# Access clustering metrics +# Access clustering metrics (xr.Dataset) metrics = fs_clustered.clustering.metrics -print(metrics) +print(metrics) # Shows RMSE, MAE, etc. per time series -# Metrics include RMSE, MAE per time series -# Use these to assess if more clusters are needed +# Access specific metric +rmse = metrics['RMSE'] # xr.DataArray with dims [time_series, period?, scenario?] ``` ## Storage Modes diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 34aba2ded..9c900593a 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -38,15 +38,15 @@ class ClusterStructure: which is needed for proper storage state-of-charge tracking across typical periods when using cluster(). - Note: "original_period" here refers to the original time chunks before - clustering (e.g., 365 original days), NOT the model's "period" dimension - (years/months). Each original time chunk gets assigned to a cluster. + Note: The "original_cluster" dimension indexes the original cluster-sized + time segments (e.g., 0..364 for 365 days), NOT the model's "period" dimension + (years). Each original segment gets assigned to a representative cluster. Attributes: - cluster_order: Maps each original time chunk index to its cluster ID. 
- dims: [original_period] for simple case, or - [original_period, period, scenario] for multi-period/scenario systems. - Values are cluster indices (0 to n_clusters-1). + cluster_order: Maps original cluster index → representative cluster ID. + dims: [original_cluster] for simple case, or + [original_cluster, period, scenario] for multi-period/scenario systems. + Values are cluster IDs (0 to n_clusters-1). cluster_occurrences: Count of how many original time chunks each cluster represents. dims: [cluster] for simple case, or [cluster, period, scenario] for multi-dim. n_clusters: Number of distinct clusters (typical periods). @@ -60,7 +60,7 @@ class ClusterStructure: - timesteps_per_cluster: 24 (for hourly data) For multi-scenario (e.g., 2 scenarios): - - cluster_order: shape (365, 2) with dims [original_period, scenario] + - cluster_order: shape (365, 2) with dims [original_cluster, scenario] - cluster_occurrences: shape (8, 2) with dims [cluster, scenario] """ @@ -73,7 +73,7 @@ def __post_init__(self): """Validate and ensure proper DataArray formatting.""" # Ensure cluster_order is a DataArray with proper dims if not isinstance(self.cluster_order, xr.DataArray): - self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_period'], name='cluster_order') + self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_cluster'], name='cluster_order') elif self.cluster_order.name is None: self.cluster_order = self.cluster_order.rename('cluster_order') @@ -92,7 +92,7 @@ def __repr__(self) -> str: occ = [int(self.cluster_occurrences.sel(cluster=c).values) for c in range(n_clusters)] return ( f'ClusterStructure(\n' - f' {self.n_original_periods} original periods → {n_clusters} clusters\n' + f' {self.n_original_clusters} original periods → {n_clusters} clusters\n' f' timesteps_per_cluster={self.timesteps_per_cluster}\n' f' occurrences={occ}\n' f')' @@ -124,9 +124,9 @@ def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: return 
ref, arrays @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" - return len(self.cluster_order.coords['original_period']) + return len(self.cluster_order.coords['original_cluster']) @property def has_multi_dims(self) -> bool: @@ -236,7 +236,7 @@ def plot(self, show: bool | None = None) -> PlotResult: y=[1] * len(df), color='Cluster', color_continuous_scale='Viridis', - title=f'Cluster Assignment ({self.n_original_periods} periods → {n_clusters} clusters)', + title=f'Cluster Assignment ({self.n_original_clusters} periods → {n_clusters} clusters)', ) fig.update_layout(yaxis_visible=False, coloraxis_colorbar_title='Cluster') @@ -532,30 +532,30 @@ def validate(self) -> None: # (each weight is how many original periods that cluster represents) # Sum should be checked per period/scenario slice, not across all dimensions if self.cluster_structure is not None: - n_original_periods = self.cluster_structure.n_original_periods + n_original_clusters = self.cluster_structure.n_original_clusters # Sum over cluster dimension only (keep period/scenario if present) weight_sum_per_slice = self.representative_weights.sum(dim='cluster') # Check each slice if weight_sum_per_slice.size == 1: # Simple case: no period/scenario weight_sum = float(weight_sum_per_slice.values) - if abs(weight_sum - n_original_periods) > 1e-6: + if abs(weight_sum - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum ({weight_sum}) does not match ' - f'n_original_periods ({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) else: # Multi-dimensional: check each slice for val in weight_sum_per_slice.values.flat: - if abs(float(val) - n_original_periods) > 1e-6: + if abs(float(val) - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum per slice ({float(val)}) does not match ' - f'n_original_periods 
({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) break # Only warn once @@ -993,7 +993,9 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). - metrics: Clustering quality metrics (RMSE, MAE, etc.) per time series. + metrics: Clustering quality metrics (RMSE, MAE, etc.) as xr.Dataset. + Each metric (e.g., 'RMSE', 'MAE') is a DataArray with dims + ``[time_series, period?, scenario?]``. Example: >>> fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') @@ -1005,7 +1007,7 @@ class Clustering: result: ClusterResult backend_name: str = 'unknown' - metrics: pd.DataFrame | None = None + metrics: xr.Dataset | None = None def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create reference structure for serialization.""" @@ -1028,7 +1030,7 @@ def __repr__(self) -> str: n_clusters = ( int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) ) - structure_info = f'{cs.n_original_periods} periods → {n_clusters} clusters' + structure_info = f'{cs.n_original_clusters} periods → {n_clusters} clusters' else: structure_info = 'no structure' return f'Clustering(\n backend={self.backend_name!r}\n {structure_info}\n)' @@ -1073,11 +1075,11 @@ def n_clusters(self) -> int: return int(n) if isinstance(n, (int, np.integer)) else int(n.values) @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" if self.result.cluster_structure is None: raise ValueError('No cluster_structure available') - return self.result.cluster_structure.n_original_periods + return self.result.cluster_structure.n_original_clusters @property def timesteps_per_period(self) -> int: @@ -1154,17 +1156,17 @@ def create_cluster_structure_from_mapping( ClusterStructure derived from the 
mapping. """ n_original = len(timestep_mapping) - n_original_periods = n_original // timesteps_per_cluster + n_original_clusters = n_original // timesteps_per_cluster # Determine cluster order from the mapping # Each original period maps to the cluster of its first timestep cluster_order = [] - for p in range(n_original_periods): + for p in range(n_original_clusters): start_idx = p * timesteps_per_cluster cluster_idx = int(timestep_mapping.isel(original_time=start_idx).values) // timesteps_per_cluster cluster_order.append(cluster_idx) - cluster_order_da = xr.DataArray(cluster_order, dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_order, dims=['original_cluster'], name='cluster_order') # Count occurrences of each cluster unique_clusters = np.unique(cluster_order) diff --git a/flixopt/clustering/intercluster_helpers.py b/flixopt/clustering/intercluster_helpers.py index d2a5eb9d3..a89a80862 100644 --- a/flixopt/clustering/intercluster_helpers.py +++ b/flixopt/clustering/intercluster_helpers.py @@ -132,7 +132,7 @@ def extract_capacity_bounds( def build_boundary_coords( - n_original_periods: int, + n_original_clusters: int, flow_system: FlowSystem, ) -> tuple[dict, list[str]]: """Build coordinates and dimensions for SOC_boundary variable. @@ -146,7 +146,7 @@ def build_boundary_coords( multi-period or stochastic optimizations. Args: - n_original_periods: Number of original (non-aggregated) time periods. + n_original_clusters: Number of original (non-aggregated) time periods. For example, if a year is clustered into 8 typical days but originally had 365 days, this would be 365. flow_system: The FlowSystem containing optional period/scenario dimensions. 
@@ -163,7 +163,7 @@ def build_boundary_coords( >>> coords['cluster_boundary'] array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) """ - n_boundaries = n_original_periods + 1 + n_boundaries = n_original_clusters + 1 coords = {'cluster_boundary': np.arange(n_boundaries)} dims = ['cluster_boundary'] diff --git a/flixopt/components.py b/flixopt/components.py index 390fc6f02..e962791d8 100644 --- a/flixopt/components.py +++ b/flixopt/components.py @@ -1195,7 +1195,7 @@ class InterclusterStorageModel(StorageModel): Variables Created ----------------- - ``SOC_boundary``: Absolute SOC at each original period boundary. - Shape: (n_original_periods + 1,) plus any period/scenario dimensions. + Shape: (n_original_clusters + 1,) plus any period/scenario dimensions. Constraints Created ------------------- @@ -1330,7 +1330,7 @@ def _add_intercluster_linking(self) -> None: else int(cluster_structure.n_clusters.values) ) timesteps_per_cluster = cluster_structure.timesteps_per_cluster - n_original_periods = cluster_structure.n_original_periods + n_original_clusters = cluster_structure.n_original_clusters cluster_order = cluster_structure.cluster_order # 1. Constrain ΔE = 0 at cluster starts @@ -1338,7 +1338,7 @@ def _add_intercluster_linking(self) -> None: # 2. Create SOC_boundary variable flow_system = self._model.flow_system - boundary_coords, boundary_dims = build_boundary_coords(n_original_periods, flow_system) + boundary_coords, boundary_dims = build_boundary_coords(n_original_clusters, flow_system) capacity_bounds = extract_capacity_bounds(self.element.capacity_in_flow_hours, boundary_coords, boundary_dims) soc_boundary = self.add_variables( @@ -1360,12 +1360,14 @@ def _add_intercluster_linking(self) -> None: delta_soc = self._compute_delta_soc(n_clusters, timesteps_per_cluster) # 5. 
Add linking constraints - self._add_linking_constraints(soc_boundary, delta_soc, cluster_order, n_original_periods, timesteps_per_cluster) + self._add_linking_constraints( + soc_boundary, delta_soc, cluster_order, n_original_clusters, timesteps_per_cluster + ) # 6. Add cyclic or initial constraint if self.element.cluster_mode == 'intercluster_cyclic': self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='cyclic', ) else: @@ -1375,7 +1377,8 @@ def _add_intercluster_linking(self) -> None: if isinstance(initial, str): # 'equals_final' means cyclic self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) + == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='initial_SOC_boundary', ) else: @@ -1389,7 +1392,7 @@ def _add_intercluster_linking(self) -> None: soc_boundary, cluster_order, capacity_bounds.has_investment, - n_original_periods, + n_original_clusters, timesteps_per_cluster, ) @@ -1438,7 +1441,7 @@ def _add_linking_constraints( soc_boundary: xr.DataArray, delta_soc: xr.DataArray, cluster_order: xr.DataArray, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints linking consecutive SOC_boundary values. @@ -1455,17 +1458,17 @@ def _add_linking_constraints( soc_boundary: SOC_boundary variable. delta_soc: Net SOC change per cluster. cluster_order: Mapping from original periods to representative clusters. - n_original_periods: Number of original (non-clustered) periods. + n_original_clusters: Number of original (non-clustered) periods. timesteps_per_cluster: Number of timesteps in each cluster period. 
""" soc_after = soc_boundary.isel(cluster_boundary=slice(1, None)) soc_before = soc_boundary.isel(cluster_boundary=slice(None, -1)) # Rename for alignment - soc_after = soc_after.rename({'cluster_boundary': 'original_period'}) - soc_after = soc_after.assign_coords(original_period=np.arange(n_original_periods)) - soc_before = soc_before.rename({'cluster_boundary': 'original_period'}) - soc_before = soc_before.assign_coords(original_period=np.arange(n_original_periods)) + soc_after = soc_after.rename({'cluster_boundary': 'original_cluster'}) + soc_after = soc_after.assign_coords(original_cluster=np.arange(n_original_clusters)) + soc_before = soc_before.rename({'cluster_boundary': 'original_cluster'}) + soc_before = soc_before.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get delta_soc for each original period using cluster_order delta_soc_ordered = delta_soc.isel(cluster=cluster_order) @@ -1484,7 +1487,7 @@ def _add_combined_bound_constraints( soc_boundary: xr.DataArray, cluster_order: xr.DataArray, has_investment: bool, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints ensuring actual SOC stays within bounds. @@ -1498,21 +1501,21 @@ def _add_combined_bound_constraints( middle, and end of each cluster. With 2D (cluster, time) structure, we simply select charge_state at a - given time offset, then reorder by cluster_order to get original_period order. + given time offset, then reorder by cluster_order to get original_cluster order. Args: soc_boundary: SOC_boundary variable. cluster_order: Mapping from original periods to clusters. has_investment: Whether the storage has investment sizing. - n_original_periods: Number of original periods. + n_original_clusters: Number of original periods. timesteps_per_cluster: Timesteps in each cluster. 
""" charge_state = self.charge_state # soc_d: SOC at start of each original period soc_d = soc_boundary.isel(cluster_boundary=slice(None, -1)) - soc_d = soc_d.rename({'cluster_boundary': 'original_period'}) - soc_d = soc_d.assign_coords(original_period=np.arange(n_original_periods)) + soc_d = soc_d.rename({'cluster_boundary': 'original_cluster'}) + soc_d = soc_d.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get self-discharge rate for decay calculation # Keep as DataArray to respect per-period/scenario values @@ -1523,13 +1526,13 @@ def _add_combined_bound_constraints( for sample_name, offset in zip(['start', 'mid', 'end'], sample_offsets, strict=False): # With 2D structure: select time offset, then reorder by cluster_order cs_at_offset = charge_state.isel(time=offset) # Shape: (cluster, ...) - # Reorder to original_period order using cluster_order indexer + # Reorder to original_cluster order using cluster_order indexer cs_t = cs_at_offset.isel(cluster=cluster_order) # Suppress xarray warning about index loss - we immediately assign new coords anyway with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='.*does not create an index anymore.*') - cs_t = cs_t.rename({'cluster': 'original_period'}) - cs_t = cs_t.assign_coords(original_period=np.arange(n_original_periods)) + cs_t = cs_t.rename({'cluster': 'original_cluster'}) + cs_t = cs_t.assign_coords(original_cluster=np.arange(n_original_clusters)) # Apply decay factor (1-loss)^t to SOC_boundary per Eq. 
9 decay_t = (1 - rel_loss) ** offset diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 010499f46..51fcb6f6f 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -591,7 +591,7 @@ def cluster( ] = 'new_cluster_center', rescale_cluster_periods: bool = True, random_state: int | None = None, - predef_cluster_order: np.ndarray | list[int] | None = None, + predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, **tsam_kwargs: Any, ) -> FlowSystem: """ @@ -634,6 +634,9 @@ def cluster( predef_cluster_order: Predefined cluster assignments for manual clustering. Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. + For multi-dimensional FlowSystems, use an xr.DataArray with dims + ``[original_cluster, period?, scenario?]`` to specify different assignments + per period/scenario combination. **tsam_kwargs: Additional keyword arguments passed to ``tsam.TimeSeriesAggregation``. See tsam documentation for all options. 
@@ -714,6 +717,9 @@ def cluster( if random_state is not None: np.random.seed(random_state) + # Collect metrics per (period, scenario) slice + clustering_metrics_all: dict[tuple, pd.DataFrame] = {} + for period_label in periods: for scenario_label in scenarios: key = (period_label, scenario_label) @@ -725,6 +731,16 @@ def cluster( if selector: logger.info(f'Clustering {", ".join(f"{k}={v}" for k, v in selector.items())}...') + # Handle predef_cluster_order for multi-dimensional case + predef_order_slice = None + if predef_cluster_order is not None: + if isinstance(predef_cluster_order, xr.DataArray): + # Extract slice for this (period, scenario) combination + predef_order_slice = predef_cluster_order.sel(**selector, drop=True).values + else: + # Simple array/list - use directly + predef_order_slice = predef_cluster_order + # Use tsam directly clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) tsam_agg = tsam.TimeSeriesAggregation( @@ -736,7 +752,7 @@ def cluster( extremePeriodMethod=extreme_period_method, representationMethod=representation_method, rescaleClusterPeriods=rescale_cluster_periods, - predefClusterOrder=predef_cluster_order, + predefClusterOrder=predef_order_slice, weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], @@ -750,11 +766,44 @@ def cluster( tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur + clustering_metrics_all[key] = tsam_agg.accuracyIndicators() - # Use first result for structure and metrics + # Use first result for structure first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] - clustering_metrics = first_tsam.accuracyIndicators() + + # Convert metrics to xr.Dataset with period/scenario dims if multi-dimensional + if len(clustering_metrics_all) == 1: + # Simple case: convert 
single DataFrame to Dataset + metrics_df = clustering_metrics_all[first_key] + clustering_metrics = xr.Dataset( + { + col: xr.DataArray( + metrics_df[col].values, dims=['time_series'], coords={'time_series': metrics_df.index} + ) + for col in metrics_df.columns + } + ) + else: + # Multi-dim case: combine metrics into Dataset with period/scenario dims + # First, get the metric columns from any DataFrame + sample_df = next(iter(clustering_metrics_all.values())) + metric_names = list(sample_df.columns) + time_series_names = list(sample_df.index) + + # Build DataArrays for each metric + data_vars = {} + for metric in metric_names: + # Shape: (time_series, period?, scenario?) + slices = {} + for (p, s), df in clustering_metrics_all.items(): + slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) + + da = self._combine_slices_to_dataarray_generic(slices, ['time_series'], periods, scenarios, metric) + da = da.assign_coords(time_series=time_series_names) + data_vars[metric] = da + + clustering_metrics = xr.Dataset(data_vars) n_reduced_timesteps = len(first_tsam.typicalPeriods) actual_n_clusters = len(first_tsam.clusterPeriodNoOccur) @@ -888,7 +937,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Build multi-dimensional arrays if has_periods or has_scenarios: # Multi-dimensional case: build arrays for each (period, scenario) combination - # cluster_order: dims [original_period, period?, scenario?] + # cluster_order: dims [original_cluster, period?, scenario?] 
cluster_order_slices = {} timestep_mapping_slices = {} cluster_occurrences_slices = {} @@ -900,7 +949,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: for s in scenarios: key = (p, s) cluster_order_slices[key] = xr.DataArray( - cluster_orders[key], dims=['original_period'], name='cluster_order' + cluster_orders[key], dims=['original_cluster'], name='cluster_order' ) timestep_mapping_slices[key] = xr.DataArray( _build_timestep_mapping_for_key(key), @@ -914,7 +963,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Combine slices into multi-dimensional DataArrays cluster_order_da = self._combine_slices_to_dataarray_generic( - cluster_order_slices, ['original_period'], periods, scenarios, 'cluster_order' + cluster_order_slices, ['original_cluster'], periods, scenarios, 'cluster_order' ) timestep_mapping_da = self._combine_slices_to_dataarray_generic( timestep_mapping_slices, ['original_time'], periods, scenarios, 'timestep_mapping' @@ -924,7 +973,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: ) else: # Simple case: single (None, None) slice - cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_cluster'], name='cluster_order') # Use renamed timesteps as coordinates original_timesteps_coord = self._fs.timesteps.rename('original_time') timestep_mapping_da = xr.DataArray( @@ -1034,7 +1083,7 @@ def _combine_slices_to_dataarray_generic( Args: slices: Dict mapping (period, scenario) tuples to DataArrays. - base_dims: Base dimensions of each slice (e.g., ['original_period'] or ['original_time']). + base_dims: Base dimensions of each slice (e.g., ['original_cluster'] or ['original_time']). periods: List of period labels ([None] if no periods dimension). scenarios: List of scenario labels ([None] if no scenarios dimension). name: Name for the resulting DataArray. 
@@ -1123,7 +1172,7 @@ def expand_solution(self) -> FlowSystem: disaggregates the FlowSystem by: 1. Expanding all time series data from typical clusters to full timesteps 2. Expanding the solution by mapping each typical cluster back to all - original segments it represents + original clusters it represents For FlowSystems with periods and/or scenarios, each (period, scenario) combination is expanded using its own cluster assignment. @@ -1159,7 +1208,7 @@ def expand_solution(self) -> FlowSystem: Note: The expanded FlowSystem repeats the typical cluster values for all - segments belonging to the same cluster. Both input data and solution + original clusters belonging to the same cluster. Both input data and solution are consistently expanded, so they match. This is an approximation - the actual dispatch at full resolution would differ due to intra-cluster variations in time series data. @@ -1261,12 +1310,12 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_charge_state = expanded_fs._solution[charge_state_name] # Map each original timestep to its original period index - original_period_indices = np.arange(n_original_timesteps) // timesteps_per_cluster + original_cluster_indices = np.arange(n_original_timesteps) // timesteps_per_cluster # Select SOC_boundary for each timestep (boundary[d] for period d) - # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_periods-1 + # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_clusters-1 soc_boundary_per_timestep = soc_boundary.isel( - cluster_boundary=xr.DataArray(original_period_indices, dims=['time']) + cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time']) ) soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) @@ -1293,14 +1342,14 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) n_combinations = 
len(periods) * len(scenarios) - n_original_segments = cluster_structure.n_original_periods + n_original_clusters = cluster_structure.n_original_clusters logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' f'({n_clusters} clusters' + ( f', {n_combinations} period/scenario combinations)' if n_combinations > 1 - else f' → {n_original_segments} original segments)' + else f' → {n_original_clusters} original clusters)' ) ) diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index 7072fe22e..b54eeb56e 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -449,9 +449,9 @@ def test_storage_cluster_mode_intercluster(self, solver_fixture, timesteps_8_day soc_boundary = fs_clustered.solution['Battery|SOC_boundary'] assert 'cluster_boundary' in soc_boundary.dims - # Number of boundaries = n_original_periods + 1 - n_original_periods = fs_clustered.clustering.result.cluster_structure.n_original_periods - assert soc_boundary.sizes['cluster_boundary'] == n_original_periods + 1 + # Number of boundaries = n_original_clusters + 1 + n_original_clusters = fs_clustered.clustering.result.cluster_structure.n_original_clusters + assert soc_boundary.sizes['cluster_boundary'] == n_original_clusters + 1 def test_storage_cluster_mode_intercluster_cyclic(self, solver_fixture, timesteps_8_days): """Storage with cluster_mode='intercluster_cyclic' - linked with yearly cycling.""" diff --git a/tests/test_clustering/test_base.py b/tests/test_clustering/test_base.py index 9c63f25f6..9cca4de81 100644 --- a/tests/test_clustering/test_base.py +++ b/tests/test_clustering/test_base.py @@ -17,7 +17,7 @@ class TestClusterStructure: def test_basic_creation(self): """Test basic ClusterStructure creation.""" - cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_period']) + cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_cluster']) cluster_occurrences = 
xr.DataArray([3, 2, 1], dims=['cluster']) structure = ClusterStructure( @@ -29,7 +29,7 @@ def test_basic_creation(self): assert structure.n_clusters == 3 assert structure.timesteps_per_cluster == 24 - assert structure.n_original_periods == 6 + assert structure.n_original_clusters == 6 def test_creation_from_numpy(self): """Test ClusterStructure creation from numpy arrays.""" @@ -42,12 +42,12 @@ def test_creation_from_numpy(self): assert isinstance(structure.cluster_order, xr.DataArray) assert isinstance(structure.cluster_occurrences, xr.DataArray) - assert structure.n_original_periods == 5 + assert structure.n_original_clusters == 5 def test_get_cluster_weight_per_timestep(self): """Test weight calculation per timestep.""" structure = ClusterStructure( - cluster_order=xr.DataArray([0, 1, 0], dims=['original_period']), + cluster_order=xr.DataArray([0, 1, 0], dims=['original_cluster']), cluster_occurrences=xr.DataArray([2, 1], dims=['cluster']), n_clusters=2, timesteps_per_cluster=4, @@ -136,7 +136,7 @@ def test_basic_creation(self): structure = create_cluster_structure_from_mapping(mapping, timesteps_per_cluster=4) assert structure.timesteps_per_cluster == 4 - assert structure.n_original_periods == 3 + assert structure.n_original_clusters == 3 class TestClustering: diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 2bcd0b022..d6dd3d2e7 100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -214,8 +214,9 @@ def test_metrics_available(self, basic_flow_system): fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') assert fs_clustered.clustering.metrics is not None - assert isinstance(fs_clustered.clustering.metrics, pd.DataFrame) - assert len(fs_clustered.clustering.metrics) > 0 + assert isinstance(fs_clustered.clustering.metrics, xr.Dataset) + assert 'time_series' in fs_clustered.clustering.metrics.dims + assert 
len(fs_clustered.clustering.metrics.data_vars) > 0 def test_representation_method_parameter(self, basic_flow_system): """Test that representation_method parameter works.""" @@ -237,6 +238,35 @@ def test_tsam_kwargs_passthrough(self, basic_flow_system): fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', sameMean=True) assert len(fs_clustered.clusters) == 2 + def test_metrics_with_periods(self): + """Test that metrics have period dimension for multi-period FlowSystems.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem( + timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h'), + periods=pd.Index([2025, 2030], name='period'), + ) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Metrics should have period dimension + assert fs_clustered.clustering.metrics is not None + assert 'period' in fs_clustered.clustering.metrics.dims + assert len(fs_clustered.clustering.metrics.period) == 2 + class TestClusteringModuleImports: """Tests for flixopt.clustering module imports.""" From 21f96c2d1d3cccb040a6b917658f18fab4d380fe Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 12:45:49 +0100 Subject: [PATCH 07/30] Problem: Expanded FlowSystem from clustering didn't have the extra timestep that regular FlowSystems have. 
Root Cause: In expand_solution(), the solution was only indexed by original_timesteps (n elements) instead of original_timesteps_extra (n+1 elements). Fix in flixopt/transform_accessor.py: 1. Reindex solution to timesteps_extra (line 1296-1298): - Added expanded_fs._solution.reindex(time=original_timesteps_extra) for consistency with non-expanded FlowSystems 2. Fill extra timestep for charge_state (lines 1300-1333): - Added special handling to properly fill the extra timestep for storage charge_state variables using the last cluster's extra timestep value 3. Updated intercluster storage handling (lines 1340-1388): - Modified to work with original_timesteps_extra instead of just original_timesteps - The extra timestep now correctly gets the final SOC boundary value with proper decay applied Tests updated in tests/test_cluster_reduce_expand.py: - Updated 4 assertions that check solution time coordinates to expect 193 (192 + 1 extra) instead of 192 --- flixopt/transform_accessor.py | 51 ++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 51fcb6f6f..b466d928a 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -1249,18 +1249,38 @@ def expand_solution(self) -> FlowSystem: scenarios = list(self._fs.scenarios) if has_scenarios else [None] n_original_timesteps = len(original_timesteps) n_reduced_timesteps = n_clusters * timesteps_per_cluster + n_original_clusters = cluster_structure.n_original_clusters # Expand function using ClusterResult.expand_data() - handles multi-dimensional cases - def expand_da(da: xr.DataArray) -> xr.DataArray: + # For charge_state with cluster dim, also includes the extra timestep + last_original_cluster_idx = (n_original_timesteps - 1) // timesteps_per_cluster + + def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: if 'time' not in da.dims: return da.copy() - return 
info.result.expand_data(da, original_time=original_timesteps) + expanded = info.result.expand_data(da, original_time=original_timesteps) + + # For charge_state with cluster dim, append the extra timestep value + if var_name.endswith('|charge_state') and 'cluster' in da.dims: + # Get extra timestep from last cluster using vectorized selection + cluster_order = cluster_structure.cluster_order # (n_original_clusters,) or with period/scenario + if cluster_order.ndim == 1: + last_cluster = int(cluster_order[last_original_cluster_idx]) + extra_val = da.isel(cluster=last_cluster, time=-1) + else: + # Multi-dimensional: select last cluster for each period/scenario slice + last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) + extra_val = da.isel(cluster=last_clusters, time=-1) + extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) + expanded = xr.concat([expanded, extra_val], dim='time') + + return expanded # 1. Expand FlowSystem data (with cluster_weight set to 1.0 for all timesteps) reduced_ds = self._fs.to_dataset(include_solution=False) # Filter out cluster-related variables and copy attrs without clustering info data_vars = { - name: expand_da(da) + name: expand_da(da, name) for name, da in reduced_ds.data_vars.items() if name != 'cluster_weight' and not name.startswith('clustering|') } @@ -1288,17 +1308,22 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs = FlowSystem.from_dataset(expanded_ds) # 2. 
Expand solution + # charge_state variables get their extra timestep via expand_da; others get NaN via reindex reduced_solution = self._fs.solution expanded_fs._solution = xr.Dataset( - {name: expand_da(da) for name, da in reduced_solution.data_vars.items()}, + {name: expand_da(da, name) for name, da in reduced_solution.data_vars.items()}, attrs=reduced_solution.attrs, ) + # Reindex to timesteps_extra for consistency with non-expanded FlowSystems + # (variables without extra timestep data will have NaN at the final timestep) + expanded_fs._solution = expanded_fs._solution.reindex(time=original_timesteps_extra) # 3. Combine charge_state with SOC_boundary for InterclusterStorageModel storages # For intercluster storages, charge_state is relative (ΔE) and can be negative. # Per Blanke et al. (2022) Eq. 9, actual SOC at time t in period d is: # SOC(t) = SOC_boundary[d] * (1 - loss)^t_within_period + charge_state(t) # where t_within_period is hours from period start (accounts for self-discharge decay). 
+ n_original_timesteps_extra = len(original_timesteps_extra) soc_boundary_vars = [name for name in reduced_solution.data_vars if name.endswith('|SOC_boundary')] for soc_boundary_name in soc_boundary_vars: storage_name = soc_boundary_name.rsplit('|', 1)[0] @@ -1309,24 +1334,31 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: soc_boundary = reduced_solution[soc_boundary_name] expanded_charge_state = expanded_fs._solution[charge_state_name] - # Map each original timestep to its original period index - original_cluster_indices = np.arange(n_original_timesteps) // timesteps_per_cluster + # Map each original timestep (including extra) to its original period index + # The extra timestep belongs to the last period + original_cluster_indices = np.minimum( + np.arange(n_original_timesteps_extra) // timesteps_per_cluster, + n_original_clusters - 1, + ) # Select SOC_boundary for each timestep (boundary[d] for period d) # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_clusters-1 soc_boundary_per_timestep = soc_boundary.isel( cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time']) ) - soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) + soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps_extra) # Apply self-discharge decay to SOC_boundary based on time within period # Get the storage's relative_loss_per_hour from the clustered flow system storage = self._fs.storages.get(storage_name) if storage is not None: # Time within period for each timestep (0, 1, 2, ..., timesteps_per_cluster-1, 0, 1, ...) 
- time_within_period = np.arange(n_original_timesteps) % timesteps_per_cluster + # The extra timestep is at index timesteps_per_cluster (one past the last within-cluster index) + time_within_period = np.arange(n_original_timesteps_extra) % timesteps_per_cluster + # The extra timestep gets the correct decay (timesteps_per_cluster) + time_within_period[-1] = timesteps_per_cluster time_within_period_da = xr.DataArray( - time_within_period, dims=['time'], coords={'time': original_timesteps} + time_within_period, dims=['time'], coords={'time': original_timesteps_extra} ) # Decay factor: (1 - loss)^t, using mean loss over time # Keep as DataArray to respect per-period/scenario values @@ -1342,7 +1374,6 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) n_combinations = len(periods) * len(scenarios) - n_original_clusters = cluster_structure.n_original_clusters logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' f'({n_clusters} clusters' From 8ffd18587a6940a69e9beee4740a7391fa739096 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:18:47 +0100 Subject: [PATCH 08/30] - 'variable' is treated as a special valid facet value (since it exists in the melted DataFrame from data_var names, not as a dimension) - When facet_row='variable' or facet_col='variable' is passed, it's passed through directly - In line(), when faceting by variable, it's not also used for color (avoids double encoding) --- flixopt/dataset_plot_accessor.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index fc38f730b..270293165 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -34,7 +34,11 @@ def _resolve_auto_facets( animation_frame: str | Literal['auto'] 
| None = None, exclude_dims: set[str] | None = None, ) -> tuple[str | None, str | None, str | None]: - """Assign 'auto' facet slots from available dims using CONFIG priority lists.""" + """Assign 'auto' facet slots from available dims using CONFIG priority lists. + + Special handling for 'variable': exists in melted DataFrame (from data_var names), + not as a dimension, so it's always valid as an explicit facet request. + """ # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} @@ -50,9 +54,10 @@ def _resolve_auto_facets( results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used + # 'variable' is special - exists in melted df from data_var names, not in ds.dims for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value in available and value not in used: + if value == 'variable' or (value in available and value not in used): used.add(value) results[slot_name] = value @@ -325,8 +330,13 @@ def line( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: + # Only color by variable if it's not used for faceting or x-axis (and user didn't override) + if ( + x_col != 'variable' + and actual_facet_col != 'variable' + and actual_facet_row != 'variable' + and 'color' not in px_kwargs + ): fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: From 57c9cb1aace43618b79599bf173b08c57aae35c3 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:51:42 +0100 Subject: [PATCH 09/30] Add variable and color to auto resolving in fxplot --- flixopt/config.py | 
6 +- flixopt/dataset_plot_accessor.py | 146 ++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 56 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 602652252..9793f9ba2 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -164,9 +164,9 @@ def format(self, record): 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', - 'extra_dim_priority': ('cluster', 'period', 'scenario'), - 'dim_slot_priority': ('facet_col', 'facet_row', 'animation_frame'), - 'x_dim_priority': ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster'), + 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), + 'dim_slot_priority': ('color', 'facet_col', 'facet_row', 'animation_frame'), + 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), } ), 'solving': MappingProxyType( diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 270293165..403a22a46 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -13,14 +13,25 @@ from .config import CONFIG -def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str: - """Select x-axis dim from priority list, or 'variable' for scalar data.""" +def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: + """Select x-axis dim from priority list, or 'variable' for scalar data. + + Args: + dims: List of available dimensions. + n_data_vars: Number of data variables (for 'variable' availability). + x: Explicit x-axis choice or 'auto'. 
+ """ if x and x != 'auto': return x + # 'variable' is available when there are multiple data_vars + available = set(dims) + if n_data_vars > 1: + available.add('variable') + # Check priority list first for dim in CONFIG.Plotting.x_dim_priority: - if dim in dims: + if dim in available: return dim # Fallback to first available dimension, or 'variable' for scalar data @@ -29,35 +40,47 @@ def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str def _resolve_auto_facets( ds: xr.Dataset, + color: str | Literal['auto'] | None, facet_col: str | Literal['auto'] | None, facet_row: str | Literal['auto'] | None, animation_frame: str | Literal['auto'] | None = None, exclude_dims: set[str] | None = None, -) -> tuple[str | None, str | None, str | None]: +) -> tuple[str | None, str | None, str | None, str | None]: """Assign 'auto' facet slots from available dims using CONFIG priority lists. - Special handling for 'variable': exists in melted DataFrame (from data_var names), - not as a dimension, so it's always valid as an explicit facet request. + 'variable' is treated like a dimension - available when len(data_vars) > 1. + It exists in the melted DataFrame from data_var names, not in ds.dims. + + Returns: + Tuple of (color, facet_col, facet_row, animation_frame). 
""" # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} + # 'variable' is available when there are multiple data_vars + if len(ds.data_vars) > 1: + available.add('variable') extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] used: set[str] = set() # Map slot names to their input values slots = { + 'color': color, 'facet_col': facet_col, 'facet_row': facet_row, 'animation_frame': animation_frame, } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} + results: dict[str, str | None] = { + 'color': None, + 'facet_col': None, + 'facet_row': None, + 'animation_frame': None, + } # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used - # 'variable' is special - exists in melted df from data_var names, not in ds.dims for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value == 'variable' or (value in available and value not in used): + if value in available and value not in used: used.add(value) results[slot_name] = value @@ -70,7 +93,7 @@ def _resolve_auto_facets( used.add(next_dim) results[slot_name] = next_dim - return results['facet_col'], results['facet_row'], results['animation_frame'] + return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -125,6 +148,7 @@ def bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -139,6 +163,8 @@ def bar( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 
'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -154,17 +180,20 @@ def bar( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -174,9 +203,8 @@ def bar( 'title': title, 'barmode': 'group', } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -198,6 +226,7 @@ def stacked_bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -215,6 +244,8 @@ def 
stacked_bar( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -230,17 +261,20 @@ def stacked_bar( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -249,9 +283,8 @@ def stacked_bar( 'y': 'value', 'title': title, } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -276,6 +309,7 @@ def line( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | 
Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -293,6 +327,8 @@ def line( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -310,17 +346,20 @@ def line( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -330,14 +369,8 @@ def line( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not used for faceting or x-axis (and user didn't override) - if ( - x_col != 'variable' - and actual_facet_col != 'variable' - and actual_facet_row != 'variable' - and 'color' not in px_kwargs - ): - fig_kwargs['color'] = 'variable' + if actual_color and 'color' 
not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -359,6 +392,7 @@ def area( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -374,6 +408,8 @@ def area( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -390,17 +426,20 @@ def area( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -410,9 +449,8 @@ def area( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not already on 
x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -477,7 +515,7 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, facet_col, None, animation_frame) + _, actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, None, facet_col, None, animation_frame) imshow_args: dict[str, Any] = { 'img': da, @@ -535,8 +573,8 @@ def scatter( if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, None, facet_col, facet_row, animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -619,8 +657,8 @@ def pie( if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, None, facet_col, facet_row, animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -882,7 +920,9 @@ def heatmap( # Use Dataset for facet resolution ds_for_resolution = da.to_dataset(name='_temp') - actual_facet_col, _, actual_anim = _resolve_auto_facets(ds_for_resolution, facet_col, None, animation_frame) + _, actual_facet_col, _, actual_anim = _resolve_auto_facets( + ds_for_resolution, None, facet_col, None, animation_frame + ) imshow_args: dict[str, Any] = { 'img': da, From df1fac1e30de08bea3eba6fa5368968aca6efb6c Mon Sep 17 00:00:00 2001 From: FBumann 
<117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:53:58 +0100 Subject: [PATCH 10/30] Added 'variable' to both priority lists and updated the logic to treat it consistently: flixopt/config.py: 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), flixopt/dataset_plot_accessor.py: - _get_x_dim: Now takes n_data_vars parameter; 'variable' is available when > 1 - _resolve_auto_facets: 'variable' is available when len(data_vars) > 1 and respects exclude_dims Behavior: - 'variable' is treated like any other dimension in the priority system - Only available when there are multiple data_vars - Properly excluded when already used (e.g., for x-axis) --- flixopt/clustering/base.py | 181 ++++++++++++------------------- flixopt/dataset_plot_accessor.py | 4 +- 2 files changed, 69 insertions(+), 116 deletions(-) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 9c900593a..f914514b4 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -197,20 +197,20 @@ def get_cluster_weight_per_timestep(self) -> xr.DataArray: name='cluster_weight', ) - def plot(self, show: bool | None = None) -> PlotResult: + def plot(self, colors: str | list[str] | None = None, show: bool | None = None) -> PlotResult: """Plot cluster assignment visualization. Shows which cluster each original period belongs to, and the number of occurrences per cluster. Args: + colors: Colorscale name (str) or list of colors. + Defaults to CONFIG.Plotting.default_sequential_colorscale. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. Returns: PlotResult containing the figure and underlying data. 
""" - import plotly.express as px - from ..config import CONFIG from ..plot_result import PlotResult @@ -218,27 +218,24 @@ def plot(self, show: bool | None = None) -> PlotResult: int(self.n_clusters) if isinstance(self.n_clusters, (int, np.integer)) else int(self.n_clusters.values) ) - # Create DataFrame for plotting - import pandas as pd - cluster_order = self.get_cluster_order_for_slice() - df = pd.DataFrame( - { - 'Original Period': range(1, len(cluster_order) + 1), - 'Cluster': cluster_order, - } + + # Build DataArray for fxplot heatmap + cluster_da = xr.DataArray( + cluster_order.reshape(1, -1), + dims=['y', 'original_cluster'], + coords={'y': ['Cluster'], 'original_cluster': range(1, len(cluster_order) + 1)}, + name='cluster_assignment', ) - # Bar chart showing cluster assignment - fig = px.bar( - df, - x='Original Period', - y=[1] * len(df), - color='Cluster', - color_continuous_scale='Viridis', + # Use fxplot.heatmap for smart defaults + colorscale = colors or CONFIG.Plotting.default_sequential_colorscale + fig = cluster_da.fxplot.heatmap( + colors=colorscale, title=f'Cluster Assignment ({self.n_original_clusters} periods → {n_clusters} clusters)', ) - fig.update_layout(yaxis_visible=False, coloraxis_colorbar_title='Cluster') + fig.update_yaxes(showticklabels=False) + fig.update_coloraxes(colorbar_title='Cluster') # Build data for PlotResult data = xr.Dataset( @@ -585,8 +582,8 @@ def compare( *, select: SelectType | None = None, colors: ColorType | None = None, - facet_col: str | None = 'period', - facet_row: str | None = 'scenario', + facet_col: str | None = 'auto', + facet_row: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -600,8 +597,10 @@ def compare( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col: Dimension for subplot columns (default: 'period'). 
- facet_row: Dimension for subplot rows (default: 'scenario'). + facet_col: Dimension for subplot columns. 'auto' uses CONFIG priority. + Use 'variable' to create separate columns per variable. + facet_row: Dimension for subplot rows. 'auto' uses CONFIG priority. + Use 'variable' to create separate rows per variable. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -610,9 +609,7 @@ def compare( PlotResult containing the comparison figure and underlying data. """ import pandas as pd - import plotly.express as px - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -626,7 +623,7 @@ def compare( resolved_variables = self._resolve_variables(variables) - # Build Dataset with 'representation' dimension for Original/Clustered + # Build Dataset with variables as data_vars data_vars = {} for var in resolved_variables: original = result.original_data[var] @@ -650,54 +647,34 @@ def compare( { var: xr.DataArray( [sorted_vars[(var, r)] for r in ['Original', 'Clustered']], - dims=['representation', 'rank'], - coords={'representation': ['Original', 'Clustered'], 'rank': range(n)}, + dims=['representation', 'duration'], + coords={'representation': ['Original', 'Clustered'], 'duration': range(n)}, ) for var in resolved_variables } ) - # Resolve facets (only for timeseries) - actual_facet_col = facet_col if kind == 'timeseries' and facet_col in ds.dims else None - actual_facet_row = facet_row if kind == 'timeseries' and facet_row in ds.dims else None - - # Convert to long-form DataFrame - df = ds.to_dataframe().reset_index() - coord_cols = [c for c in ds.coords.keys() if c in df.columns] - df = df.melt(id_vars=coord_cols, var_name='variable', value_name='value') - - variable_labels = df['variable'].unique().tolist() - color_map = process_colors(colors, variable_labels, 
CONFIG.Plotting.default_qualitative_colorscale) - - # Set x-axis and title based on kind - x_col = 'time' if kind == 'timeseries' else 'rank' + # Set title based on kind if kind == 'timeseries': title = ( 'Original vs Clustered' if len(resolved_variables) > 1 else f'Original vs Clustered: {resolved_variables[0]}' ) - labels = {} else: title = 'Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}' - labels = {'rank': 'Hours (sorted)', 'value': 'Value'} - fig = px.line( - df, - x=x_col, - y='value', - color='variable', - line_dash='representation', - facet_col=actual_facet_col, - facet_row=actual_facet_row, + # Use fxplot for smart defaults with line_dash for representation + fig = ds.fxplot.line( + colors=colors, title=title, - labels=labels, - color_discrete_map=color_map, + facet_col=facet_col, + facet_row=facet_row, + line_dash='representation', **plotly_kwargs, ) - if actual_facet_row or actual_facet_col: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) plot_result = PlotResult(data=ds, figure=fig) @@ -743,8 +720,8 @@ def heatmap( *, select: SelectType | None = None, colors: str | list[str] | None = None, - facet_col: str | None = 'period', - animation_frame: str | None = 'scenario', + facet_col: str | None = 'auto', + animation_frame: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -762,8 +739,8 @@ def heatmap( colors: Colorscale name (str) or list of colors for heatmap coloring. Dicts are not supported for heatmaps. Defaults to CONFIG.Plotting.default_sequential_colorscale. - facet_col: Dimension to facet on columns (default: 'period'). - animation_frame: Dimension for animation slider (default: 'scenario'). + facet_col: Dimension to facet on columns. 'auto' uses CONFIG priority. 
+ animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -773,7 +750,6 @@ def heatmap( The data has 'cluster' variable with time dimension, matching original timesteps. """ import pandas as pd - import plotly.express as px from ..config import CONFIG from ..plot_result import PlotResult @@ -833,34 +809,24 @@ def heatmap( else: cluster_da = cluster_slices[(None, None)] - # Resolve facet_col and animation_frame - only use if dimension exists - actual_facet_col = facet_col if facet_col and facet_col in cluster_da.dims else None - actual_animation = animation_frame if animation_frame and animation_frame in cluster_da.dims else None - # Add dummy y dimension for heatmap visualization (single row) heatmap_da = cluster_da.expand_dims('y', axis=-1) heatmap_da = heatmap_da.assign_coords(y=['Cluster']) + heatmap_da.name = 'cluster_assignment' - colorscale = colors or CONFIG.Plotting.default_sequential_colorscale - - # Use px.imshow with xr.DataArray - fig = px.imshow( - heatmap_da, - color_continuous_scale=colorscale, - facet_col=actual_facet_col, - animation_frame=actual_animation, + # Use fxplot.heatmap for smart defaults + fig = heatmap_da.fxplot.heatmap( + colors=colors, title='Cluster Assignments', - labels={'time': 'Time', 'color': 'Cluster'}, + facet_col=facet_col, + animation_frame=animation_frame, aspect='auto', **plotly_kwargs, ) - # Clean up facet labels - if actual_facet_col: - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - # Hide y-axis since it's just a single row + # Clean up: hide y-axis since it's just a single row fig.update_yaxes(showticklabels=False) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) # Data is exactly what we plotted (without dummy y dimension) cluster_da.name = 'cluster' @@ -880,21 +846,21 @@ def clusters( *, select: SelectType 
| None = None, colors: ColorType | None = None, - facet_col_wrap: int | None = None, + facet_cols: int | None = None, show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: """Plot each cluster's typical period profile. - Shows each cluster as a separate faceted subplot. Useful for - understanding what each cluster represents. + Shows each cluster as a separate faceted subplot with all variables + colored differently. Useful for understanding what each cluster represents. Args: variables: Variable(s) to plot. Can be a string, list of strings, or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col_wrap: Max columns before wrapping facets. + facet_cols: Max columns before wrapping facets. Defaults to CONFIG.Plotting.default_facet_cols. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. @@ -903,10 +869,6 @@ def clusters( Returns: PlotResult containing the figure and underlying data. 
""" - import pandas as pd - import plotly.express as px - - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -929,45 +891,36 @@ def clusters( n_clusters = int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) timesteps_per_cluster = cs.timesteps_per_cluster - # Build long-form DataFrame with cluster labels including occurrence counts - rows = [] + # Build Dataset with cluster dimension, using labels with occurrence counts + cluster_labels = [ + f'Cluster {c} (×{int(cs.cluster_occurrences.sel(cluster=c).values)})' for c in range(n_clusters) + ] + data_vars = {} for var in resolved_variables: data = aggregated_data[var].values data_by_cluster = data.reshape(n_clusters, timesteps_per_cluster) data_vars[var] = xr.DataArray( data_by_cluster, - dims=['cluster', 'timestep'], - coords={'cluster': range(n_clusters), 'timestep': range(timesteps_per_cluster)}, + dims=['cluster', 'time'], + coords={'cluster': cluster_labels, 'time': range(timesteps_per_cluster)}, ) - for c in range(n_clusters): - occurrence = int(cs.cluster_occurrences.sel(cluster=c).values) - label = f'Cluster {c} (×{occurrence})' - for t in range(timesteps_per_cluster): - rows.append({'cluster': label, 'timestep': t, 'value': data_by_cluster[c, t], 'variable': var}) - df = pd.DataFrame(rows) - - cluster_labels = df['cluster'].unique().tolist() - color_map = process_colors(colors, cluster_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols + + ds = xr.Dataset(data_vars) title = 'Clusters' if len(resolved_variables) > 1 else f'Clusters: {resolved_variables[0]}' - fig = px.line( - df, - x='timestep', - y='value', - facet_col='cluster', - facet_row='variable' if len(resolved_variables) > 1 else None, - facet_col_wrap=facet_col_wrap if len(resolved_variables) == 1 else None, + # 
Use fxplot for smart defaults + fig = ds.fxplot.line( + colors=colors, title=title, - color_discrete_map=color_map, + facet_col='cluster', + facet_cols=facet_cols, **plotly_kwargs, ) - fig.update_layout(showlegend=False) - if len(resolved_variables) > 1: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + # Include occurrences in result data data_vars['occurrences'] = cs.cluster_occurrences result_data = xr.Dataset(data_vars) plot_result = PlotResult(data=result_data, figure=fig) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 403a22a46..73b20b436 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -57,8 +57,8 @@ def _resolve_auto_facets( # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} - # 'variable' is available when there are multiple data_vars - if len(ds.data_vars) > 1: + # 'variable' is available when there are multiple data_vars (and not excluded) + if len(ds.data_vars) > 1 and 'variable' not in exclude: available.add('variable') extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] used: set[str] = set() From 829c34244b6edce8b446e4ece70cfa3c09e6385d Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:10:57 +0100 Subject: [PATCH 11/30] Improve plotting, especially for clustering --- docs/notebooks/08c-clustering.ipynb | 272 ++++++++++++++++++---------- flixopt/clustering/base.py | 26 ++- tests/test_cluster_reduce_expand.py | 8 +- 3 files changed, 202 insertions(+), 104 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 0e9cda7b7..d07512cac 100644 --- 
a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -28,10 +28,8 @@ "source": [ "import timeit\n", "\n", - "import numpy as np\n", "import pandas as pd\n", - "import plotly.graph_objects as go\n", - "from plotly.subplots import make_subplots\n", + "import xarray as xr\n", "\n", "import flixopt as fx\n", "\n", @@ -73,18 +71,13 @@ "outputs": [], "source": [ "# Visualize input data\n", - "heat_demand = flow_system.components['HeatDemand'].inputs[0].fixed_relative_profile\n", - "electricity_price = flow_system.components['GridBuy'].outputs[0].effects_per_flow_hour['costs']\n", - "\n", - "fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1)\n", - "fig.add_trace(go.Scatter(x=timesteps, y=heat_demand.values, name='Heat Demand', line=dict(width=0.5)), row=1, col=1)\n", - "fig.add_trace(\n", - " go.Scatter(x=timesteps, y=electricity_price.values, name='Electricity Price', line=dict(width=0.5)), row=2, col=1\n", + "input_ds = xr.Dataset(\n", + " {\n", + " 'Heat Demand': flow_system.components['HeatDemand'].inputs[0].fixed_relative_profile,\n", + " 'Electricity Price': flow_system.components['GridBuy'].outputs[0].effects_per_flow_hour['costs'],\n", + " }\n", ")\n", - "fig.update_layout(height=400, title='One Month of Input Data')\n", - "fig.update_yaxes(title_text='Heat Demand [MW]', row=1, col=1)\n", - "fig.update_yaxes(title_text='El. 
Price [€/MWh]', row=2, col=1)\n", - "fig.show()" + "input_ds.fxplot.line(facet_row='variable', title='One Month of Input Data')" ] }, { @@ -154,11 +147,16 @@ " n_clusters=8, # 8 typical days\n", " cluster_duration='1D', # Daily clustering\n", " time_series_for_high_peaks=peak_series, # Capture peak demand day\n", + " random_state=42, # Reproducible results\n", ")\n", "\n", "time_clustering = timeit.default_timer() - start\n", - "print(f'Clustering time: {time_clustering:.1f} seconds')\n", - "print(f'Reduced: {len(flow_system.timesteps)} → {len(fs_clustered.timesteps)} timesteps')" + "\n", + "print(\n", + " f'Clustering: {len(flow_system.timesteps)} → {len(fs_clustered.timesteps) * len(fs_clustered.clusters)} timesteps'\n", + ")\n", + "print(f' Clusters: {len(fs_clustered.clusters)}')\n", + "print(f' Time: {time_clustering:.2f}s')" ] }, { @@ -188,7 +186,7 @@ "source": [ "## Understanding the Clustering\n", "\n", - "The clustering algorithm groups similar days together. Let's inspect the cluster structure:" + "The clustering algorithm groups similar days together. 
Access all metadata via `fs.clustering`:" ] }, { @@ -198,26 +196,110 @@ "metadata": {}, "outputs": [], "source": [ - "# Show clustering info\n", - "info = fs_clustered.clustering\n", - "cs = info.result.cluster_structure\n", - "print('Clustering Configuration:')\n", - "print(f' Number of typical periods: {cs.n_clusters}')\n", - "print(f' Timesteps per period: {cs.timesteps_per_cluster}')\n", - "print(f' Total reduced timesteps: {cs.n_clusters * cs.timesteps_per_cluster}')\n", - "print(f' Cluster order (first 10 days): {cs.cluster_order.values[:10]}...')\n", + "# Access clustering metadata directly\n", + "clustering = fs_clustered.clustering\n", + "clustering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Key properties\n", + "print(f'Clusters: {clustering.n_clusters}')\n", + "print(f'Original segments (days): {clustering.n_original_clusters}')\n", + "print(f'Timesteps per cluster: {clustering.timesteps_per_cluster}')\n", + "print(f'\\nCluster occurrences: {clustering.occurrences.values}')\n", + "print(f'Cluster order: {clustering.cluster_order.values}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Quality metrics - how well do the clusters represent the original data?\n", + "# Lower RMSE/MAE = better representation\n", + "clustering.metrics.to_dataframe().style.format('{:.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Visual comparison: original vs clustered time series\n", + "clustering.plot.compare()" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Advanced Clustering Options\n", + "\n", + "The `cluster()` method exposes many parameters for fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": 
[], + "source": [ + "# Try different clustering algorithms\n", + "fs_hierarchical = flow_system.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " cluster_method='hierarchical', # Alternative: 'k_means' (default), 'k_medoids', 'averaging'\n", + " random_state=42,\n", + ")\n", "\n", - "# Show how many times each cluster appears\n", - "cluster_order = cs.cluster_order.values\n", - "unique, counts = np.unique(cluster_order, return_counts=True)\n", - "print('\\nCluster occurrences:')\n", - "for cluster_id, count in zip(unique, counts, strict=False):\n", - " print(f' Cluster {cluster_id}: {count} days')" + "# Compare cluster assignments between algorithms\n", + "print('k_means clusters: ', fs_clustered.clustering.cluster_order.values)\n", + "print('hierarchical clusters:', fs_hierarchical.clustering.cluster_order.values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare RMSE between algorithms\n", + "print('Quality comparison (RMSE for HeatDemand):')\n", + "print(\n", + " f' k_means: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + ")\n", + "print(\n", + " f' hierarchical: {float(fs_hierarchical.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize cluster structure with heatmap\n", + "clustering.plot.heatmap()" ] }, { "cell_type": "markdown", - "id": "12", + "id": "19", "metadata": {}, "source": [ "## Method 3: Two-Stage Workflow (Recommended)\n", @@ -235,7 +317,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -256,7 +338,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "21", "metadata": 
{}, "outputs": [], "source": [ @@ -279,7 +361,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "22", "metadata": {}, "source": [ "## Compare Results" @@ -288,7 +370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -337,7 +419,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "24", "metadata": {}, "source": [ "## Expand Solution to Full Resolution\n", @@ -349,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -363,34 +445,29 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "26", "metadata": {}, "outputs": [], "source": [ - "# Compare heat balance: Full vs Expanded\n", - "fig = make_subplots(rows=2, cols=1, shared_xaxes=True, subplot_titles=['Full Optimization', 'Expanded from Clustering'])\n", + "# Compare heat production: Full vs Expanded\n", + "heat_flows = ['CHP(Q_th)|flow_rate', 'Boiler(Q_th)|flow_rate']\n", "\n", - "# Full\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_full.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(go.Scatter(x=fs_full.timesteps, y=values, name=var, legendgroup=var, showlegend=True), row=1, col=1)\n", - "\n", - "# Expanded\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_expanded.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(\n", - " go.Scatter(x=fs_expanded.timesteps, y=values, name=var, legendgroup=var, showlegend=False), row=2, col=1\n", - " )\n", + "# Create comparison dataset\n", + "comparison_ds = xr.Dataset(\n", + " {\n", + " name.replace('|flow_rate', ''): xr.concat(\n", + " [fs_full.solution[name], fs_expanded.solution[name]], dim=pd.Index(['Full', 'Expanded'], name='method')\n", + " )\n", + " for name in heat_flows\n", + " }\n", + ")\n", "\n", - "fig.update_layout(height=500, title='Heat Production Comparison')\n", - "fig.update_yaxes(title_text='MW', row=1, col=1)\n", 
- "fig.update_yaxes(title_text='MW', row=2, col=1)\n", - "fig.show()" + "comparison_ds.fxplot.line(facet_col='variable', color='method', title='Heat Production Comparison')" ] }, { "cell_type": "markdown", - "id": "20", + "id": "27", "metadata": {}, "source": [ "## Visualize Clustered Heat Balance" @@ -399,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -409,33 +486,55 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "29", "metadata": {}, "outputs": [], "source": [ - "fs_expanded.statistics.plot.storage('Storage')" + "fs_expanded.statistics.plot.storage('Storage').data.to_dataframe()" ] }, { "cell_type": "markdown", - "id": "23", + "id": "30", "metadata": {}, "source": [ "## API Reference\n", "\n", "### `transform.cluster()` Parameters\n", "\n", - "| Parameter | Type | Description |\n", - "|-----------|------|-------------|\n", - "| `n_clusters` | `int` | Number of typical periods (e.g., 8 typical days) |\n", - "| `cluster_duration` | `str \\| float` | Duration per cluster ('1D', '24h') or hours |\n", - "| `weights` | `dict[str, float]` | Optional weights for time series in clustering |\n", - "| `time_series_for_high_peaks` | `list[str]` | **Essential**: Force inclusion of peak periods |\n", - "| `time_series_for_low_peaks` | `list[str]` | Force inclusion of minimum periods |\n", + "| Parameter | Type | Default | Description |\n", + "|-----------|------|---------|-------------|\n", + "| `n_clusters` | `int` | - | Number of typical periods (e.g., 8 typical days) |\n", + "| `cluster_duration` | `str \\| float` | - | Duration per cluster ('1D', '24h') or hours |\n", + "| `weights` | `dict[str, float]` | None | Optional weights for time series in clustering |\n", + "| `time_series_for_high_peaks` | `list[str]` | None | **Essential**: Force inclusion of peak periods |\n", + "| `time_series_for_low_peaks` | `list[str]` | None | Force inclusion of minimum periods 
|\n", + "| `cluster_method` | `str` | 'k_means' | Algorithm: 'k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging' |\n", + "| `representation_method` | `str` | 'meanRepresentation' | 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' |\n", + "| `extreme_period_method` | `str` | 'new_cluster_center' | How peaks are integrated: 'None', 'append', 'new_cluster_center', 'replace_cluster_center' |\n", + "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", + "| `random_state` | `int` | None | Random seed for reproducibility |\n", + "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", + "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", + "\n", + "### Clustering Object Properties\n", + "\n", + "After clustering, access metadata via `fs.clustering`:\n", + "\n", + "| Property | Description |\n", + "|----------|-------------|\n", + "| `n_clusters` | Number of clusters |\n", + "| `n_original_clusters` | Number of original time segments (e.g., 365 days) |\n", + "| `timesteps_per_cluster` | Timesteps in each cluster (e.g., 24 for daily) |\n", + "| `cluster_order` | xr.DataArray mapping original segment → cluster ID |\n", + "| `occurrences` | How many original segments each cluster represents |\n", + "| `metrics` | xr.Dataset with RMSE, MAE per time series |\n", + "| `plot.compare()` | Compare original vs clustered time series |\n", + "| `plot.heatmap()` | Visualize cluster structure |\n", "\n", "### Storage Behavior\n", "\n", - "Each `Storage` component has a `cluster_storage_mode` parameter that controls how it behaves during clustering:\n", + "Each `Storage` component has a `cluster_mode` parameter:\n", "\n", "| Mode | Description |\n", "|------|-------------|\n", @@ -444,37 +543,12 @@ "| `'cyclic'` | Each cluster is independent but cyclic (start = end) |\n", "| `'independent'` | Each cluster is independent, free start/end |\n", "\n", - "For a detailed 
comparison of storage modes, see [08c2-clustering-storage-modes](08c2-clustering-storage-modes.ipynb).\n", - "\n", - "### Peak Forcing Format\n", - "\n", - "```python\n", - "time_series_for_high_peaks = ['ComponentName(FlowName)|fixed_relative_profile']\n", - "```\n", - "\n", - "### Recommended Workflow\n", - "\n", - "```python\n", - "# Stage 1: Fast sizing\n", - "fs_sizing = flow_system.transform.cluster(\n", - " n_clusters=8,\n", - " cluster_duration='1D',\n", - " time_series_for_high_peaks=['Demand(Flow)|fixed_relative_profile'],\n", - ")\n", - "fs_sizing.optimize(solver)\n", - "\n", - "# Apply safety margin\n", - "sizes = {k: v.item() * 1.05 for k, v in fs_sizing.statistics.sizes.items()}\n", - "\n", - "# Stage 2: Accurate dispatch\n", - "fs_dispatch = flow_system.transform.fix_sizes(sizes)\n", - "fs_dispatch.optimize(solver)\n", - "```" + "For a detailed comparison of storage modes, see [08c2-clustering-storage-modes](08c2-clustering-storage-modes.ipynb)." ] }, { "cell_type": "markdown", - "id": "24", + "id": "31", "metadata": {}, "source": [ "## Summary\n", @@ -485,13 +559,17 @@ "- Apply **peak forcing** to capture extreme demand days\n", "- Use **two-stage optimization** for fast yet accurate investment decisions\n", "- **Expand solutions** back to full resolution with `expand_solution()`\n", + "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", + "- Use **advanced options** like different algorithms and reproducible random states\n", "\n", "### Key Takeaways\n", "\n", "1. **Always use peak forcing** (`time_series_for_high_peaks`) for demand time series\n", "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", - "4. **Storage handling** is configurable via `storage_mode`\n", + "4. **Storage handling** is configurable via `cluster_mode`\n", + "5. **Use `random_state`** for reproducible results\n", + "6. 
**Check metrics** to evaluate clustering quality\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index f914514b4..ab9590aae 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -582,6 +582,8 @@ def compare( *, select: SelectType | None = None, colors: ColorType | None = None, + color: str | None = 'auto', + line_dash: str | None = 'representation', facet_col: str | None = 'auto', facet_row: str | None = 'auto', show: bool | None = None, @@ -597,6 +599,10 @@ def compare( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'representation' to color by Original/Clustered instead of line_dash. + line_dash: Dimension for line dash styles. Defaults to 'representation'. + Set to None to disable line dash differentiation. facet_col: Dimension for subplot columns. 'auto' uses CONFIG priority. Use 'variable' to create separate columns per variable. facet_row: Dimension for subplot rows. 'auto' uses CONFIG priority. 
@@ -664,13 +670,20 @@ def compare( else: title = 'Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}' - # Use fxplot for smart defaults with line_dash for representation + # Use fxplot for smart defaults + line_kwargs = {} + if line_dash is not None: + line_kwargs['line_dash'] = line_dash + if line_dash == 'representation': + line_kwargs['line_dash_map'] = {'Original': 'dot', 'Clustered': 'solid'} + fig = ds.fxplot.line( colors=colors, + color=color, title=title, facet_col=facet_col, facet_row=facet_row, - line_dash='representation', + **line_kwargs, **plotly_kwargs, ) fig.update_yaxes(matches=None) @@ -846,6 +859,8 @@ def clusters( *, select: SelectType | None = None, colors: ColorType | None = None, + color: str | None = 'auto', + facet_col: str | None = 'cluster', facet_cols: int | None = None, show: bool | None = None, **plotly_kwargs: Any, @@ -860,6 +875,10 @@ def clusters( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'cluster' to color by cluster instead of faceting. + facet_col: Dimension for subplot columns. Defaults to 'cluster'. + Use 'variable' to facet by variable instead. facet_cols: Max columns before wrapping facets. Defaults to CONFIG.Plotting.default_facet_cols. show: Whether to display the figure. 
@@ -912,8 +931,9 @@ def clusters( # Use fxplot for smart defaults fig = ds.fxplot.line( colors=colors, + color=color, title=title, - facet_col='cluster', + facet_col=facet_col, facet_cols=facet_cols, **plotly_kwargs, ) diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index b54eeb56e..4059470ee 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -167,7 +167,7 @@ def test_expand_solution_enables_statistics_accessor(solver_fixture, timesteps_8 # These should work without errors flow_rates = fs_expanded.statistics.flow_rates assert 'Boiler(Q_th)' in flow_rates - assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 192 + assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 193 # 192 + 1 extra timestep flow_hours = fs_expanded.statistics.flow_hours assert 'Boiler(Q_th)' in flow_hours @@ -321,7 +321,7 @@ def test_cluster_and_expand_with_scenarios(solver_fixture, timesteps_8_days, sce flow_var = 'Boiler(Q_th)|flow_rate' assert flow_var in fs_expanded.solution assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_expand_solution_maps_scenarios_independently(solver_fixture, timesteps_8_days, scenarios_2): @@ -693,7 +693,7 @@ def test_expand_solution_with_periods(self, solver_fixture, timesteps_8_days, pe # Solution should have period dimension flow_var = 'Boiler(Q_th)|flow_rate' assert 'period' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_days, periods_2, scenarios_2): """Clustering should work with both periods and scenarios.""" @@ -719,7 +719,7 @@ def 
test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_da fs_expanded = fs_clustered.transform.expand_solution() assert 'period' in fs_expanded.solution[flow_var].dims assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep # ==================== Peak Selection Tests ==================== From ed80f89893be2dcc1503ea808f22a1859f96a17b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:15:35 +0100 Subject: [PATCH 12/30] Drop cluster index when expanding --- flixopt/transform_accessor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index b466d928a..8ec350c95 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -1271,6 +1271,8 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: # Multi-dimensional: select last cluster for each period/scenario slice last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) extra_val = da.isel(cluster=last_clusters, time=-1) + # Drop 'cluster' coord created by advanced indexing (non-dim coord from isel) + extra_val = extra_val.drop_vars('cluster', errors='ignore') extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) expanded = xr.concat([expanded, extra_val], dim='time') From 3dc7eec2d238bedd9179d03854f32ef5fb58cef8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 15:21:33 +0100 Subject: [PATCH 13/30] Fix storage expansion --- flixopt/transform_accessor.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 8ec350c95..43b15d440 100644 --- a/flixopt/transform_accessor.py +++ 
b/flixopt/transform_accessor.py @@ -1271,8 +1271,8 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: # Multi-dimensional: select last cluster for each period/scenario slice last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) extra_val = da.isel(cluster=last_clusters, time=-1) - # Drop 'cluster' coord created by advanced indexing (non-dim coord from isel) - extra_val = extra_val.drop_vars('cluster', errors='ignore') + # Drop 'cluster'/'time' coords created by isel (kept as non-dim coords) + extra_val = extra_val.drop_vars(['cluster', 'time'], errors='ignore') extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) expanded = xr.concat([expanded, extra_val], dim='time') @@ -1363,10 +1363,15 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: time_within_period, dims=['time'], coords={'time': original_timesteps_extra} ) # Decay factor: (1 - loss)^t, using mean loss over time - # Keep as DataArray to respect per-period/scenario values loss_value = storage.relative_loss_per_hour.mean('time') if (loss_value > 0).any(): decay_da = (1 - loss_value) ** time_within_period_da + if 'cluster' in decay_da.dims: + # Map each timestep to its cluster's decay value + cluster_per_timestep = cluster_structure.cluster_order.values[original_cluster_indices] + decay_da = decay_da.isel(cluster=xr.DataArray(cluster_per_timestep, dims=['time'])).drop_vars( + 'cluster', errors='ignore' + ) soc_boundary_per_timestep = soc_boundary_per_timestep * decay_da # Combine: actual_SOC = SOC_boundary * decay + charge_state @@ -1375,6 +1380,14 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: combined_charge_state = (expanded_charge_state + soc_boundary_per_timestep).clip(min=0) expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) + # Remove SOC_boundary variables - they're cluster-specific and now incorporated into charge_state + for 
soc_boundary_name in soc_boundary_vars: + if soc_boundary_name in expanded_fs._solution: + del expanded_fs._solution[soc_boundary_name] + # Also drop the cluster_boundary coordinate (orphaned after removing SOC_boundary) + if 'cluster_boundary' in expanded_fs._solution.coords: + expanded_fs._solution = expanded_fs._solution.drop_vars('cluster_boundary') + n_combinations = len(periods) * len(scenarios) logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' From 6b0579f6bb24ac9b0e56c4d1f3b964f0017dc825 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 15:38:41 +0100 Subject: [PATCH 14/30] Improve clustering --- docs/notebooks/08c-clustering.ipynb | 72 +++++++++++++++++++++++------ flixopt/transform_accessor.py | 27 +++++------ 2 files changed, 73 insertions(+), 26 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index d07512cac..4d8b8a121 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -301,6 +301,50 @@ "cell_type": "markdown", "id": "19", "metadata": {}, + "source": [ + "### Manual Cluster Assignment\n", + "\n", + "When comparing design variants or performing sensitivity analysis, you often want to\n", + "use the **same cluster structure** across different FlowSystem configurations.\n", + "Use `predef_cluster_order` to ensure comparable results:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the cluster order from our optimized system\n", + "cluster_order = fs_clustered.clustering.cluster_order.values\n", + "print(f'Cluster order to reuse: {cluster_order}')\n", + "\n", + "# Now modify the FlowSystem (e.g., increase storage capacity limits)\n", + "flow_system_modified = flow_system.copy()\n", + "flow_system_modified.components['Storage'].capacity_in_flow_hours.maximum_size = 2000 # 
Larger storage option\n", + "\n", + "# Cluster with the SAME cluster structure for fair comparison\n", + "fs_modified_clustered = flow_system_modified.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " predef_cluster_order=cluster_order, # Reuse cluster assignments\n", + ")\n", + "\n", + "# Optimize the modified system\n", + "fs_modified_clustered.optimize(solver)\n", + "\n", + "print('\\nComparison (same cluster structure):')\n", + "print(f' Original storage size: {fs_clustered.statistics.sizes[\"Storage\"].item():.0f}')\n", + "print(f' Modified storage size: {fs_modified_clustered.statistics.sizes[\"Storage\"].item():.0f}')\n", + "print(f' Original cost: {fs_clustered.solution[\"costs\"].item():,.0f} €')\n", + "print(f' Modified cost: {fs_modified_clustered.solution[\"costs\"].item():,.0f} €')" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, "source": [ "## Method 3: Two-Stage Workflow (Recommended)\n", "\n", @@ -317,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -338,7 +382,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -361,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "24", "metadata": {}, "source": [ "## Compare Results" @@ -370,7 +414,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -419,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "26", "metadata": {}, "source": [ "## Expand Solution to Full Resolution\n", @@ -431,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -445,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -467,7 +511,7 @@ }, { "cell_type": "markdown", 
- "id": "27", + "id": "29", "metadata": {}, "source": [ "## Visualize Clustered Heat Balance" @@ -476,7 +520,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -486,16 +530,16 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "31", "metadata": {}, "outputs": [], "source": [ - "fs_expanded.statistics.plot.storage('Storage').data.to_dataframe()" + "fs_expanded.statistics.plot.storage('Storage')" ] }, { "cell_type": "markdown", - "id": "30", + "id": "32", "metadata": {}, "source": [ "## API Reference\n", @@ -548,7 +592,7 @@ }, { "cell_type": "markdown", - "id": "31", + "id": "33", "metadata": {}, "source": [ "## Summary\n", @@ -561,6 +605,7 @@ "- **Expand solutions** back to full resolution with `expand_solution()`\n", "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", "- Use **advanced options** like different algorithms and reproducible random states\n", + "- **Manually assign clusters** using `predef_cluster_order`\n", "\n", "### Key Takeaways\n", "\n", @@ -570,6 +615,7 @@ "4. **Storage handling** is configurable via `cluster_mode`\n", "5. **Use `random_state`** for reproducible results\n", "6. **Check metrics** to evaluate clustering quality\n", + "7. 
**Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 43b15d440..e3a41a3ba 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -582,13 +582,11 @@ def cluster( weights: dict[str, float] | None = None, time_series_for_high_peaks: list[str] | None = None, time_series_for_low_peaks: list[str] | None = None, - cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'k_means', + cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'hierarchical', representation_method: Literal[ 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' - ] = 'meanRepresentation', - extreme_period_method: Literal[ - 'None', 'append', 'new_cluster_center', 'replace_cluster_center' - ] = 'new_cluster_center', + ] = 'medoidRepresentation', + extreme_period_method: Literal['append', 'new_cluster_center', 'replace_cluster_center'] | None = None, rescale_cluster_periods: bool = True, random_state: int | None = None, predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, @@ -602,7 +600,7 @@ def cluster( through time series aggregation using the tsam package. The method: - 1. Performs time series clustering using tsam (k-means) + 1. Performs time series clustering using tsam (hierarchical by default) 2. Extracts only the typical clusters (not all original timesteps) 3. Applies timestep weighting for accurate cost representation 4. Handles storage states between clusters based on each Storage's ``cluster_mode`` @@ -619,18 +617,19 @@ def cluster( clusters. **Recommended** for demand time series to capture peak demand days. time_series_for_low_peaks: Time series labels for explicitly selecting low-value clusters. cluster_method: Clustering algorithm to use. 
Options: - ``'k_means'`` (default), ``'k_medoids'``, ``'hierarchical'``, + ``'hierarchical'`` (default), ``'k_means'``, ``'k_medoids'``, ``'k_maxoids'``, ``'averaging'``. representation_method: How cluster representatives are computed. Options: - ``'meanRepresentation'`` (default), ``'medoidRepresentation'``, + ``'medoidRepresentation'`` (default), ``'meanRepresentation'``, ``'distributionAndMinMaxRepresentation'``. extreme_period_method: How extreme periods (peaks) are integrated. Options: - ``'new_cluster_center'`` (default), ``'None'``, ``'append'``, - ``'replace_cluster_center'``. + ``None`` (default, no special handling), ``'append'``, + ``'new_cluster_center'``, ``'replace_cluster_center'``. rescale_cluster_periods: If True (default), rescale cluster periods so their weighted mean matches the original time series mean. - random_state: Random seed for reproducible clustering results. If None, - results may vary between runs. + random_state: Random seed for reproducible clustering results. Only relevant + for non-deterministic methods like ``'k_means'``. The default + ``'hierarchical'`` method is deterministic. predef_cluster_order: Predefined cluster assignments for manual clustering. Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. 
@@ -743,13 +742,15 @@ def cluster( # Use tsam directly clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) + # tsam expects 'None' as a string, not Python None + tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method tsam_agg = tsam.TimeSeriesAggregation( df, noTypicalPeriods=n_clusters, hoursPerPeriod=hours_per_cluster, resolution=dt, clusterMethod=cluster_method, - extremePeriodMethod=extreme_period_method, + extremePeriodMethod=tsam_extreme_method, representationMethod=representation_method, rescaleClusterPeriods=rescale_cluster_periods, predefClusterOrder=predef_order_slice, From b2539d86a166c0473aedf46e95ca02df9a3d4286 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:45:41 +0100 Subject: [PATCH 15/30] fix scatter plot faceting --- flixopt/dataset_plot_accessor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 73b20b436..e2802cb04 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -590,13 +590,15 @@ def scatter( if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), y: ylabel} - if actual_facet_col: + # Only use facets if the column actually exists in the dataframe + # (scatter uses wide format, so 'variable' column doesn't exist) + if actual_facet_col and actual_facet_col in df.columns: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and actual_facet_row in df.columns: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and actual_anim in df.columns: fig_kwargs['animation_frame'] = actual_anim return px.scatter(**fig_kwargs) From e48ff177f177f3f7e4b6584a44f80abd739cc07f Mon Sep 17 00:00:00 2001 
From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:48:56 +0100 Subject: [PATCH 16/30] =?UTF-8?q?=E2=8F=BA=20Fixed=20the=20documentation?= =?UTF-8?q?=20in=20the=20notebook:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Cell 32 (API Reference table): Updated defaults to 'hierarchical', 'medoidRepresentation', and None 2. Cell 16: Swapped the example to show k_means as the alternative (since hierarchical is now default) 3. Cell 17: Updated variable names to match 4. Cell 33 (Key Takeaways): Clarified that random_state is only needed for non-deterministic methods like 'k_means' The code review --- docs/notebooks/08c-clustering.ipynb | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 4d8b8a121..35b91f6eb 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -257,16 +257,16 @@ "outputs": [], "source": [ "# Try different clustering algorithms\n", - "fs_hierarchical = flow_system.transform.cluster(\n", + "fs_kmeans = flow_system.transform.cluster(\n", " n_clusters=8,\n", " cluster_duration='1D',\n", - " cluster_method='hierarchical', # Alternative: 'k_means' (default), 'k_medoids', 'averaging'\n", + " cluster_method='k_means', # Alternative: 'hierarchical' (default), 'k_medoids', 'averaging'\n", " random_state=42,\n", ")\n", "\n", "# Compare cluster assignments between algorithms\n", - "print('k_means clusters: ', fs_clustered.clustering.cluster_order.values)\n", - "print('hierarchical clusters:', fs_hierarchical.clustering.cluster_order.values)" + "print('hierarchical clusters:', fs_clustered.clustering.cluster_order.values)\n", + "print('k_means clusters: ', fs_kmeans.clustering.cluster_order.values)" ] }, { @@ -279,10 +279,10 @@ "# Compare RMSE between algorithms\n", "print('Quality comparison (RMSE for HeatDemand):')\n", 
"print(\n", - " f' k_means: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + " f' hierarchical: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", ")\n", "print(\n", - " f' hierarchical: {float(fs_hierarchical.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + " f' k_means: {float(fs_kmeans.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", ")" ] }, @@ -553,11 +553,11 @@ "| `weights` | `dict[str, float]` | None | Optional weights for time series in clustering |\n", "| `time_series_for_high_peaks` | `list[str]` | None | **Essential**: Force inclusion of peak periods |\n", "| `time_series_for_low_peaks` | `list[str]` | None | Force inclusion of minimum periods |\n", - "| `cluster_method` | `str` | 'k_means' | Algorithm: 'k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging' |\n", - "| `representation_method` | `str` | 'meanRepresentation' | 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' |\n", - "| `extreme_period_method` | `str` | 'new_cluster_center' | How peaks are integrated: 'None', 'append', 'new_cluster_center', 'replace_cluster_center' |\n", + "| `cluster_method` | `str` | 'hierarchical' | Algorithm: 'hierarchical', 'k_means', 'k_medoids', 'k_maxoids', 'averaging' |\n", + "| `representation_method` | `str` | 'medoidRepresentation' | 'medoidRepresentation', 'meanRepresentation', 'distributionAndMinMaxRepresentation' |\n", + "| `extreme_period_method` | `str \\| None` | None | How peaks are integrated: None, 'append', 'new_cluster_center', 'replace_cluster_center' |\n", "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", - "| `random_state` | `int` | None | Random seed for reproducibility |\n", + "| `random_state` | `int` | None | 
Random seed for reproducibility (only needed for non-deterministic methods like 'k_means') |\n", "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", "\n", @@ -613,7 +613,7 @@ "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", "4. **Storage handling** is configurable via `cluster_mode`\n", - "5. **Use `random_state`** for reproducible results\n", + "5. **Use `random_state`** for reproducible results when using non-deterministic methods like 'k_means' (the default 'hierarchical' is deterministic)\n", "6. **Check metrics** to evaluate clustering quality\n", "7. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", From 285e07b59b6550be0f45ec8e44d09ed4c26cf623 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:53:30 +0100 Subject: [PATCH 17/30] 1. Error handling for accuracyIndicators() - Added try/except with warning log and empty DataFrame fallback, plus handling empty DataFrames when building the metrics Dataset 2. Random state to tsam - Replaced global np.random.seed() with passing seed parameter directly to tsam's TimeSeriesAggregation 3. tsam_kwargs conflict validation - Added validation that raises ValueError if user tries to override explicit parameters via **tsam_kwargs (including seed) 4. predef_cluster_order validation - Added dimension validation for DataArray inputs, checking they match the FlowSystem's period/scenario structure 5. 
Out-of-bounds fix - Clamped last_original_cluster_idx to n_original_clusters - 1 to handle partial clusters at the end --- flixopt/transform_accessor.py | 104 +++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 26 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index e3a41a3ba..210725d7d 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -707,15 +707,46 @@ def cluster( ds = self._fs.to_dataset(include_solution=False) + # Validate tsam_kwargs doesn't override explicit parameters + reserved_tsam_keys = { + 'noTypicalPeriods', + 'hoursPerPeriod', + 'resolution', + 'clusterMethod', + 'extremePeriodMethod', + 'representationMethod', + 'rescaleClusterPeriods', + 'predefClusterOrder', + 'weightDict', + 'addPeakMax', + 'addPeakMin', + 'seed', # Controlled by random_state parameter + } + conflicts = reserved_tsam_keys & set(tsam_kwargs.keys()) + if conflicts: + raise ValueError( + f'Cannot override explicit parameters via tsam_kwargs: {conflicts}. ' + f'Use the corresponding cluster() parameters instead.' + ) + + # Validate predef_cluster_order dimensions if it's a DataArray + if isinstance(predef_cluster_order, xr.DataArray): + expected_dims = {'original_cluster'} + if has_periods: + expected_dims.add('period') + if has_scenarios: + expected_dims.add('scenario') + if set(predef_cluster_order.dims) != expected_dims: + raise ValueError( + f'predef_cluster_order dimensions {set(predef_cluster_order.dims)} ' + f'do not match expected {expected_dims} for this FlowSystem.' 
+ ) + # Cluster each (period, scenario) combination using tsam directly tsam_results: dict[tuple, tsam.TimeSeriesAggregation] = {} cluster_orders: dict[tuple, np.ndarray] = {} cluster_occurrences_all: dict[tuple, dict] = {} - # Set random seed for reproducibility - if random_state is not None: - np.random.seed(random_state) - # Collect metrics per (period, scenario) slice clustering_metrics_all: dict[tuple, pd.DataFrame] = {} @@ -744,21 +775,24 @@ def cluster( clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) # tsam expects 'None' as a string, not Python None tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method - tsam_agg = tsam.TimeSeriesAggregation( - df, - noTypicalPeriods=n_clusters, - hoursPerPeriod=hours_per_cluster, - resolution=dt, - clusterMethod=cluster_method, - extremePeriodMethod=tsam_extreme_method, - representationMethod=representation_method, - rescaleClusterPeriods=rescale_cluster_periods, - predefClusterOrder=predef_order_slice, - weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, - addPeakMax=time_series_for_high_peaks or [], - addPeakMin=time_series_for_low_peaks or [], - **tsam_kwargs, - ) + # Build tsam kwargs, including random_state if provided + tsam_init_kwargs: dict[str, Any] = { + 'noTypicalPeriods': n_clusters, + 'hoursPerPeriod': hours_per_cluster, + 'resolution': dt, + 'clusterMethod': cluster_method, + 'extremePeriodMethod': tsam_extreme_method, + 'representationMethod': representation_method, + 'rescaleClusterPeriods': rescale_cluster_periods, + 'predefClusterOrder': predef_order_slice, + 'weightDict': {name: w for name, w in clustering_weights.items() if name in df.columns}, + 'addPeakMax': time_series_for_high_peaks or [], + 'addPeakMin': time_series_for_low_peaks or [], + } + # Pass random_state to tsam instead of setting global np.random.seed() + if random_state is not None: + tsam_init_kwargs['seed'] = random_state + 
tsam_agg = tsam.TimeSeriesAggregation(df, **tsam_init_kwargs, **tsam_kwargs) # Suppress tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*') @@ -767,16 +801,26 @@ def cluster( tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur - clustering_metrics_all[key] = tsam_agg.accuracyIndicators() + # Compute accuracy metrics with error handling + try: + clustering_metrics_all[key] = tsam_agg.accuracyIndicators() + except Exception as e: + logger.warning(f'Failed to compute clustering metrics for {key}: {e}') + clustering_metrics_all[key] = pd.DataFrame() # Use first result for structure first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] # Convert metrics to xr.Dataset with period/scenario dims if multi-dimensional - if len(clustering_metrics_all) == 1: + # Filter out empty DataFrames (from failed accuracyIndicators calls) + non_empty_metrics = {k: v for k, v in clustering_metrics_all.items() if not v.empty} + if not non_empty_metrics: + # All metrics failed - create empty Dataset + clustering_metrics = xr.Dataset() + elif len(non_empty_metrics) == 1 or len(clustering_metrics_all) == 1: # Simple case: convert single DataFrame to Dataset - metrics_df = clustering_metrics_all[first_key] + metrics_df = non_empty_metrics.get(first_key) or next(iter(non_empty_metrics.values())) clustering_metrics = xr.Dataset( { col: xr.DataArray( @@ -787,8 +831,8 @@ def cluster( ) else: # Multi-dim case: combine metrics into Dataset with period/scenario dims - # First, get the metric columns from any DataFrame - sample_df = next(iter(clustering_metrics_all.values())) + # First, get the metric columns from any non-empty DataFrame + sample_df = next(iter(non_empty_metrics.values())) metric_names = list(sample_df.columns) time_series_names = 
list(sample_df.index) @@ -798,7 +842,11 @@ def cluster( # Shape: (time_series, period?, scenario?) slices = {} for (p, s), df in clustering_metrics_all.items(): - slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) + if df.empty: + # Use NaN for failed metrics + slices[(p, s)] = xr.DataArray(np.full(len(time_series_names), np.nan), dims=['time_series']) + else: + slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) da = self._combine_slices_to_dataarray_generic(slices, ['time_series'], periods, scenarios, metric) da = da.assign_coords(time_series=time_series_names) @@ -1254,7 +1302,11 @@ def expand_solution(self) -> FlowSystem: # Expand function using ClusterResult.expand_data() - handles multi-dimensional cases # For charge_state with cluster dim, also includes the extra timestep - last_original_cluster_idx = (n_original_timesteps - 1) // timesteps_per_cluster + # Clamp to valid bounds to handle partial clusters at the end + last_original_cluster_idx = min( + (n_original_timesteps - 1) // timesteps_per_cluster, + n_original_clusters - 1, + ) def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: if 'time' not in da.dims: From c126115dfa00df77ce78adca1755e796ab421ec5 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 20:43:47 +0100 Subject: [PATCH 18/30] 1. DataFrame truth ambiguity - Changed non_empty_metrics.get(first_key) or next(...) to explicit if metrics_df is None: check 2. 
removed random state --- docs/notebooks/08c-clustering.ipynb | 10 ++---- flixopt/transform_accessor.py | 42 ++++++++++------------- tests/test_clustering/test_integration.py | 10 +++--- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 35b91f6eb..9676b6992 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -147,7 +147,6 @@ " n_clusters=8, # 8 typical days\n", " cluster_duration='1D', # Daily clustering\n", " time_series_for_high_peaks=peak_series, # Capture peak demand day\n", - " random_state=42, # Reproducible results\n", ")\n", "\n", "time_clustering = timeit.default_timer() - start\n", @@ -261,7 +260,6 @@ " n_clusters=8,\n", " cluster_duration='1D',\n", " cluster_method='k_means', # Alternative: 'hierarchical' (default), 'k_medoids', 'averaging'\n", - " random_state=42,\n", ")\n", "\n", "# Compare cluster assignments between algorithms\n", @@ -557,7 +555,6 @@ "| `representation_method` | `str` | 'medoidRepresentation' | 'medoidRepresentation', 'meanRepresentation', 'distributionAndMinMaxRepresentation' |\n", "| `extreme_period_method` | `str \\| None` | None | How peaks are integrated: None, 'append', 'new_cluster_center', 'replace_cluster_center' |\n", "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", - "| `random_state` | `int` | None | Random seed for reproducibility (only needed for non-deterministic methods like 'k_means') |\n", "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", "\n", @@ -604,7 +601,7 @@ "- Use **two-stage optimization** for fast yet accurate investment decisions\n", "- **Expand solutions** back to full resolution with `expand_solution()`\n", "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", - "- Use **advanced options** like 
different algorithms and reproducible random states\n", + "- Use **advanced options** like different algorithms\n", "- **Manually assign clusters** using `predef_cluster_order`\n", "\n", "### Key Takeaways\n", @@ -613,9 +610,8 @@ "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", "4. **Storage handling** is configurable via `cluster_mode`\n", - "5. **Use `random_state`** for reproducible results when using non-deterministic methods like 'k_means' (the default 'hierarchical' is deterministic)\n", - "6. **Check metrics** to evaluate clustering quality\n", - "7. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", + "5. **Check metrics** to evaluate clustering quality\n", + "6. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 210725d7d..6a5b51caa 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -588,7 +588,6 @@ def cluster( ] = 'medoidRepresentation', extreme_period_method: Literal['append', 'new_cluster_center', 'replace_cluster_center'] | None = None, rescale_cluster_periods: bool = True, - random_state: int | None = None, predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, **tsam_kwargs: Any, ) -> FlowSystem: @@ -627,9 +626,6 @@ def cluster( ``'new_cluster_center'``, ``'replace_cluster_center'``. rescale_cluster_periods: If True (default), rescale cluster periods so their weighted mean matches the original time series mean. - random_state: Random seed for reproducible clustering results. Only relevant - for non-deterministic methods like ``'k_means'``. The default - ``'hierarchical'`` method is deterministic. predef_cluster_order: Predefined cluster assignments for manual clustering. 
Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. @@ -720,7 +716,6 @@ def cluster( 'weightDict', 'addPeakMax', 'addPeakMin', - 'seed', # Controlled by random_state parameter } conflicts = reserved_tsam_keys & set(tsam_kwargs.keys()) if conflicts: @@ -775,24 +770,21 @@ def cluster( clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) # tsam expects 'None' as a string, not Python None tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method - # Build tsam kwargs, including random_state if provided - tsam_init_kwargs: dict[str, Any] = { - 'noTypicalPeriods': n_clusters, - 'hoursPerPeriod': hours_per_cluster, - 'resolution': dt, - 'clusterMethod': cluster_method, - 'extremePeriodMethod': tsam_extreme_method, - 'representationMethod': representation_method, - 'rescaleClusterPeriods': rescale_cluster_periods, - 'predefClusterOrder': predef_order_slice, - 'weightDict': {name: w for name, w in clustering_weights.items() if name in df.columns}, - 'addPeakMax': time_series_for_high_peaks or [], - 'addPeakMin': time_series_for_low_peaks or [], - } - # Pass random_state to tsam instead of setting global np.random.seed() - if random_state is not None: - tsam_init_kwargs['seed'] = random_state - tsam_agg = tsam.TimeSeriesAggregation(df, **tsam_init_kwargs, **tsam_kwargs) + tsam_agg = tsam.TimeSeriesAggregation( + df, + noTypicalPeriods=n_clusters, + hoursPerPeriod=hours_per_cluster, + resolution=dt, + clusterMethod=cluster_method, + extremePeriodMethod=tsam_extreme_method, + representationMethod=representation_method, + rescaleClusterPeriods=rescale_cluster_periods, + predefClusterOrder=predef_order_slice, + weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, + addPeakMax=time_series_for_high_peaks or [], + addPeakMin=time_series_for_low_peaks or [], + **tsam_kwargs, + ) # Suppress 
tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*') @@ -820,7 +812,9 @@ def cluster( clustering_metrics = xr.Dataset() elif len(non_empty_metrics) == 1 or len(clustering_metrics_all) == 1: # Simple case: convert single DataFrame to Dataset - metrics_df = non_empty_metrics.get(first_key) or next(iter(non_empty_metrics.values())) + metrics_df = non_empty_metrics.get(first_key) + if metrics_df is None: + metrics_df = next(iter(non_empty_metrics.values())) clustering_metrics = xr.Dataset( { col: xr.DataArray( diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index d6dd3d2e7..16c638c95 100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -201,12 +201,12 @@ def test_cluster_method_parameter(self, basic_flow_system): ) assert len(fs_clustered.clusters) == 2 - def test_random_state_reproducibility(self, basic_flow_system): - """Test that random_state produces reproducible results.""" - fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) - fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + def test_hierarchical_is_deterministic(self, basic_flow_system): + """Test that hierarchical clustering (default) produces deterministic results.""" + fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') - # Same random state should produce identical cluster orders + # Hierarchical clustering should produce identical cluster orders xr.testing.assert_equal(fs1.clustering.cluster_order, fs2.clustering.cluster_order) def test_metrics_available(self, basic_flow_system): From 14887214f59bea1d60b66cafeebd7d884e307798 Mon Sep 17 00:00:00 2001 From: 
FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 21:51:45 +0100 Subject: [PATCH 19/30] Fix pie plot animation frame and add warnings for unassigned dims --- flixopt/dataset_plot_accessor.py | 37 ++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index e2802cb04..23c88bdb7 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from typing import Any, Literal import pandas as pd @@ -12,6 +13,8 @@ from .color_processing import ColorType, process_colors from .config import CONFIG +logger = logging.getLogger('flixopt') + def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: """Select x-axis dim from priority list, or 'variable' for scalar data. @@ -93,6 +96,25 @@ def _resolve_auto_facets( used.add(next_dim) results[slot_name] = next_dim + # Warn if any dimensions were not assigned to any slot + # Only count slots that were available (passed as 'auto' or explicit dim, not None) + available_slot_count = sum(1 for v in slots.values() if v is not None) + unassigned = available - used + if unassigned: + if available_slot_count < 4: + # Some slots weren't available (e.g., pie doesn't support animation_frame) + unavailable_slots = [k for k, v in slots.items() if v is None] + logger.warning( + f'Dimensions {unassigned} not assigned to any plot dimension. ' + f'Not available for this plot type: {unavailable_slots}. ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + ) + else: + logger.warning( + f'Dimensions {unassigned} not assigned to color/facet/animation. ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' 
+ ) + return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] @@ -610,21 +632,22 @@ def pie( title: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a pie chart from aggregated dataset values. - Extra dimensions are auto-assigned to facet_col, facet_row, and animation_frame. + Extra dimensions are auto-assigned to facet_col and facet_row. For scalar values, a single pie is shown. + Note: + ``px.pie()`` does not support animation_frame, so only facets are available. + Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. 'auto' uses CONFIG priority. - animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. facet_cols: Number of columns in facet grid wrap. **px_kwargs: Additional arguments passed to plotly.express.pie. 
@@ -654,14 +677,12 @@ def pie( **px_kwargs, ) - # Multi-dimensional case - faceted/animated pies + # Multi-dimensional case - faceted pies (px.pie doesn't support animation_frame) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, None, facet_col, facet_row, animation_frame - ) + _, actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, None, facet_col, facet_row, None) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -680,8 +701,6 @@ def pie( fig_kwargs['facet_col_wrap'] = facet_col_wrap if actual_facet_row: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: - fig_kwargs['animation_frame'] = actual_anim return px.pie(**fig_kwargs) From e18966bc8db978444e841453520c27aeb94a4f31 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 11:43:40 +0100 Subject: [PATCH 20/30] Change logger warning to regular warning --- flixopt/dataset_plot_accessor.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 735393da8..9377aed2e 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -2,7 +2,7 @@ from __future__ import annotations -import logging +import warnings from typing import Any, Literal import pandas as pd @@ -13,8 +13,6 @@ from .color_processing import ColorType, process_colors from .config import CONFIG -logger = logging.getLogger('flixopt') - def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: """Select x-axis dim from priority list, or 'variable' for scalar data. 
@@ -104,15 +102,17 @@ def _resolve_auto_facets( if available_slot_count < 4: # Some slots weren't available (e.g., pie doesn't support animation_frame) unavailable_slots = [k for k, v in slots.items() if v is None] - logger.warning( + warnings.warn( f'Dimensions {unassigned} not assigned to any plot dimension. ' f'Not available for this plot type: {unavailable_slots}. ' - f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, ) else: - logger.warning( + warnings.warn( f'Dimensions {unassigned} not assigned to color/facet/animation. ' - f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, ) return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] From 87ce35139c70e86ba5ecda8669c45566f545c1c6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:04:37 +0100 Subject: [PATCH 21/30] =?UTF-8?q?=E2=8F=BA=20The=20centralized=20slot=20as?= =?UTF-8?q?signment=20system=20is=20now=20complete.=20Here's=20a=20summary?= =?UTF-8?q?=20of=20the=20changes=20made:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes Made 1. flixopt/config.py - Replaced three separate config attributes (extra_dim_priority, dim_slot_priority, x_dim_priority) with a single unified dim_priority tuple - Updated CONFIG.Plotting class docstring and attribute definitions - Updated to_dict() method to use the new attribute - The new priority order: ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario') 2. 
flixopt/dataset_plot_accessor.py - Created new assign_slots() function that centralizes all dimension-to-slot assignment logic - Fixed slot fill order: x → color → facet_col → facet_row → animation_frame - Updated all plot methods (bar, stacked_bar, line, area, heatmap, scatter, pie) to use assign_slots() - Removed old _get_x_dim() and _resolve_auto_facets() functions - Updated docstrings to reference dim_priority instead of x_dim_priority 3. flixopt/statistics_accessor.py - Updated _resolve_auto_facets() to use the new assign_slots() function internally - Added import for assign_slots from dataset_plot_accessor Key Design Decisions - Single priority list controls all auto-assignment - Slots are filled in fixed order based on availability - None means a slot is not available for that plot type - 'auto' triggers auto-assignment from priority list - Explicit string values override auto-assignment --- flixopt/config.py | 23 +-- flixopt/dataset_plot_accessor.py | 328 +++++++++++++++---------------- flixopt/statistics_accessor.py | 39 +--- 3 files changed, 171 insertions(+), 219 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 9793f9ba2..fce943eb1 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -164,9 +164,7 @@ def format(self, record): 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', - 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), - 'dim_slot_priority': ('color', 'facet_col', 'facet_row', 'animation_frame'), - 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), + 'dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario'), } ), 'solving': MappingProxyType( @@ -562,9 +560,9 @@ class Plotting: default_facet_cols: Default number of columns for faceted plots. default_sequential_colorscale: Default colorscale for heatmaps and continuous data. 
default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). - extra_dim_priority: Order of extra dimensions when auto-assigning to slots. - dim_slot_priority: Order of slots to fill with extra dimensions. - x_dim_priority: Order of dimensions to prefer for x-axis when 'auto'. + dim_priority: Priority order for assigning dimensions to plot slots (x, color, facet, etc.). + Dimensions are assigned to slots in order: x → color → facet_col → facet_row → animation_frame. + 'variable' represents the data_var names (the column created when melting the dataset). Examples: ```python @@ -573,9 +571,8 @@ class Plotting: CONFIG.Plotting.default_sequential_colorscale = 'plasma' CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' - # Customize dimension handling for faceting - CONFIG.Plotting.extra_dim_priority = ('scenario', 'period', 'cluster') - CONFIG.Plotting.dim_slot_priority = ('facet_row', 'facet_col', 'animation_frame') + # Customize dimension priority for auto-assignment + CONFIG.Plotting.dim_priority = ('time', 'scenario', 'variable', 'period', 'cluster') ``` """ @@ -586,9 +583,7 @@ class Plotting: default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] - extra_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['extra_dim_priority'] - dim_slot_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_slot_priority'] - x_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['x_dim_priority'] + dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_priority'] class Carriers: """Default carrier definitions for common energy types. 
@@ -690,9 +685,7 @@ def to_dict(cls) -> dict: 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, 'default_line_shape': cls.Plotting.default_line_shape, - 'extra_dim_priority': cls.Plotting.extra_dim_priority, - 'dim_slot_priority': cls.Plotting.dim_slot_priority, - 'x_dim_priority': cls.Plotting.x_dim_priority, + 'dim_priority': cls.Plotting.dim_priority, }, } diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 9377aed2e..746227e45 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -14,94 +14,86 @@ from .config import CONFIG -def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: - """Select x-axis dim from priority list, or 'variable' for scalar data. - - Args: - dims: List of available dimensions. - n_data_vars: Number of data variables (for 'variable' availability). - x: Explicit x-axis choice or 'auto'. - """ - if x and x != 'auto': - return x - - # 'variable' is available when there are multiple data_vars - available = set(dims) - if n_data_vars > 1: - available.add('variable') - - # Check priority list first - for dim in CONFIG.Plotting.x_dim_priority: - if dim in available: - return dim - - # Fallback to first available dimension, or 'variable' for scalar data - return dims[0] if dims else 'variable' - - -def _resolve_auto_facets( +def assign_slots( ds: xr.Dataset, - color: str | Literal['auto'] | None, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, - exclude_dims: set[str] | None = None, -) -> tuple[str | None, str | None, str | None, str | None]: - """Assign 'auto' facet slots from available dims using CONFIG priority lists. 
+ *, + x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', +) -> dict[str, str | None]: + """Assign dimensions to plot slots using CONFIG.Plotting.dim_priority. + + Slot fill order: x → color → facet_col → facet_row → animation_frame. + Dimensions are assigned in priority order from CONFIG.Plotting.dim_priority. + + Slot values: + - 'auto': auto-assign from available dims using priority + - None: skip this slot (not available for this plot type) + - str: use this specific dimension + + 'variable' is treated as a dimension when len(data_vars) > 1. It represents + the data_var names column in the melted DataFrame. - 'variable' is treated like a dimension - available when len(data_vars) > 1. - It exists in the melted DataFrame from data_var names, not in ds.dims. + Args: + ds: Dataset to analyze for available dimensions. + x: X-axis dimension. 'auto' assigns first available from priority. + color: Color grouping dimension. + facet_col: Column faceting dimension. + facet_row: Row faceting dimension. + animation_frame: Animation slider dimension. Returns: - Tuple of (color, facet_col, facet_row, animation_frame). + Dict with keys 'x', 'color', 'facet_col', 'facet_row', 'animation_frame' + and values being assigned dimension names (or None if slot skipped/unfilled). 
""" - # Get available extra dimensions with size > 1, excluding specified dims - exclude = exclude_dims or set() - available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} - # 'variable' is available when there are multiple data_vars (and not excluded) - if len(ds.data_vars) > 1 and 'variable' not in exclude: + # Get available dimensions with size > 1 + available = {d for d in ds.dims if ds.sizes[d] > 1} + # 'variable' is available when there are multiple data_vars + if len(ds.data_vars) > 1: available.add('variable') - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() - # Map slot names to their input values + # Get priority-ordered list of available dims + priority_dims = [d for d in CONFIG.Plotting.dim_priority if d in available] + # Add any available dims not in priority list (fallback) + priority_dims.extend(d for d in available if d not in priority_dims) + + # Slot specification in fill order slots = { + 'x': x, 'color': color, 'facet_col': facet_col, 'facet_row': facet_row, 'animation_frame': animation_frame, } - results: dict[str, str | None] = { - 'color': None, - 'facet_col': None, - 'facet_row': None, - 'animation_frame': None, - } + # Fixed fill order for 'auto' assignment + slot_order = ('x', 'color', 'facet_col', 'facet_row', 'animation_frame') + + results: dict[str, str | None] = {k: None for k in slot_order} + used: set[str] = set() # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value in available and value not in used: - used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': + used.add(value) + results[slot_name] = value + + # Second pass: resolve 'auto' 
slots in fixed fill order + dim_iter = iter(d for d in priority_dims if d not in used) + for slot_name in slot_order: + if slots[slot_name] == 'auto': next_dim = next(dim_iter, None) if next_dim: used.add(next_dim) results[slot_name] = next_dim # Warn if any dimensions were not assigned to any slot - # Only count slots that were available (passed as 'auto' or explicit dim, not None) - available_slot_count = sum(1 for v in slots.values() if v is not None) unassigned = available - used if unassigned: - if available_slot_count < 4: - # Some slots weren't available (e.g., pie doesn't support animation_frame) - unavailable_slots = [k for k, v in slots.items() if v is None] + available_slots = [k for k, v in slots.items() if v is not None] + unavailable_slots = [k for k, v in slots.items() if v is None] + if unavailable_slots: warnings.warn( f'Dimensions {unassigned} not assigned to any plot dimension. ' f'Not available for this plot type: {unavailable_slots}. ' @@ -110,12 +102,12 @@ def _resolve_auto_facets( ) else: warnings.warn( - f'Dimensions {unassigned} not assigned to color/facet/animation. ' + f'Dimensions {unassigned} not assigned to any plot dimension ({available_slots}). ' f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', stacklevel=3, ) - return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] + return results def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -184,7 +176,7 @@ def bar( """Create a grouped bar chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -200,11 +192,8 @@ def bar( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -212,7 +201,7 @@ def bar( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -220,27 +209,27 @@ def bar( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'barmode': 'group', } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap 
- if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.bar(**{**fig_kwargs, **px_kwargs}) @@ -265,7 +254,7 @@ def stacked_bar( values are stacked separately. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -281,11 +270,8 @@ def stacked_bar( Returns: Plotly Figure. """ - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -293,7 +279,7 @@ def stacked_bar( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -301,26 +287,26 @@ def stacked_bar( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols 
fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) @@ -348,7 +334,7 @@ def line( Each variable in the dataset becomes a separate line. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -366,11 +352,8 @@ def line( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -378,7 +361,7 @@ def line( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -386,27 +369,27 @@ def line( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): 
fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.line(**{**fig_kwargs, **px_kwargs}) @@ -429,7 +412,7 @@ def area( """Create a stacked area chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -446,11 +429,8 @@ def area( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -458,7 +438,7 @@ def area( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -466,27 +446,27 @@ def area( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): 
fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.area(**{**fig_kwargs, **px_kwargs}) @@ -537,7 +517,10 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - _, actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, None, facet_col, None, animation_frame) + # Heatmap uses imshow - x/y come from array axes, color is continuous + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + ) imshow_args: dict[str, Any] = { 'img': da, @@ -545,13 +528,13 @@ def heatmap( 'title': title or variable, } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if slots['facet_col'] and slots['facet_col'] in da.dims: + imshow_args['facet_col'] = slots['facet_col'] + if facet_col_wrap < da.sizes[slots['facet_col']]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if slots['animation_frame'] and slots['animation_frame'] in da.dims: + imshow_args['animation_frame'] = slots['animation_frame'] return px.imshow(**{**imshow_args, **imshow_kwargs}) @@ -595,8 +578,9 @@ def scatter( if df.empty: return go.Figure() - _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, None, facet_col, facet_row, animation_frame + # Scatter uses explicit x/y variable names, not 
dimensions + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -614,14 +598,14 @@ def scatter( # Only use facets if the column actually exists in the dataframe # (scatter uses wide format, so 'variable' column doesn't exist) - if actual_facet_col and actual_facet_col in df.columns: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and slots['facet_col'] in df.columns: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and actual_facet_row in df.columns: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and actual_anim in df.columns: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and slots['facet_row'] in df.columns: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and slots['animation_frame'] in df.columns: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.scatter(**fig_kwargs) @@ -682,8 +666,10 @@ def pie( if df.empty: return go.Figure() - # Note: px.pie doesn't support animation_frame - actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, facet_col, facet_row, None) + # Pie uses 'variable' for names and 'value' for values, no x/color/animation_frame + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=None + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -696,12 +682,12 @@ def pie( **px_kwargs, } - if actual_facet_col: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if 
slots['facet_col']: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: - fig_kwargs['facet_row'] = actual_facet_row + if slots['facet_row']: + fig_kwargs['facet_row'] = slots['facet_row'] return px.pie(**fig_kwargs) @@ -940,10 +926,10 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Use Dataset for facet resolution + # Heatmap uses imshow - x/y come from array axes, color is continuous ds_for_resolution = da.to_dataset(name='_temp') - _, actual_facet_col, _, actual_anim = _resolve_auto_facets( - ds_for_resolution, None, facet_col, None, animation_frame + slots = assign_slots( + ds_for_resolution, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame ) imshow_args: dict[str, Any] = { @@ -952,12 +938,12 @@ def heatmap( 'title': title or (da.name if da.name else ''), } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if slots['facet_col'] and slots['facet_col'] in da.dims: + imshow_args['facet_col'] = slots['facet_col'] + if facet_col_wrap < da.sizes[slots['facet_col']]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if slots['animation_frame'] and slots['animation_frame'] in da.dims: + imshow_args['animation_frame'] = slots['animation_frame'] return px.imshow(**{**imshow_args, **imshow_kwargs}) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 382ed1bf0..dc43287fc 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,6 +31,7 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG +from 
.dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -188,9 +189,7 @@ def _resolve_auto_facets( ) -> tuple[str | None, str | None, str | None]: """Resolve 'auto' facet/animation dimensions based on available data dimensions. - When 'auto' is specified, extra dimensions are assigned to slots based on: - - CONFIG.Plotting.extra_dim_priority: Order of dimensions to assign. - - CONFIG.Plotting.dim_slot_priority: Order of slots to fill. + Uses assign_slots with x=None and color=None to only resolve facet/animation slots. Args: ds: Dataset to check for available dimensions. @@ -202,36 +201,10 @@ def _resolve_auto_facets( Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). Each is either a valid dimension name or None. """ - # Get available extra dimensions with size > 1, sorted by priority - available = {d for d in ds.dims if ds.sizes[d] > 1} - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() - - # Map slot names to their input values - slots = { - 'facet_col': facet_col, - 'facet_row': facet_row, - 'animation_frame': animation_frame, - } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} - - # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used - for slot_name, value in slots.items(): - if value is not None and value != 'auto': - if value in available and value not in used: - used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': - next_dim = next(dim_iter, None) - if next_dim: - used.add(next_dim) - results[slot_name] = next_dim - - return results['facet_col'], results['facet_row'], results['animation_frame'] + slots = assign_slots( + ds, x=None, color=None, 
facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + ) + return slots['facet_col'], slots['facet_row'], slots['animation_frame'] def _resolve_facets( From 947ccd9615f78c70cd38f336f9d715a84278c028 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:07:33 +0100 Subject: [PATCH 22/30] Add slot_order to config --- flixopt/config.py | 13 ++++++++++--- flixopt/dataset_plot_accessor.py | 8 ++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index fce943eb1..3bc3d5ebf 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -165,6 +165,7 @@ def format(self, record): 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', 'dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario'), + 'slot_priority': ('x', 'color', 'facet_col', 'facet_row', 'animation_frame'), } ), 'solving': MappingProxyType( @@ -560,9 +561,10 @@ class Plotting: default_facet_cols: Default number of columns for faceted plots. default_sequential_colorscale: Default colorscale for heatmaps and continuous data. default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). - dim_priority: Priority order for assigning dimensions to plot slots (x, color, facet, etc.). - Dimensions are assigned to slots in order: x → y → color → facet_col → facet_row → animation_frame. - 'value' represents the y-axis values (from data_var names after melting). + dim_priority: Priority order for assigning dimensions to plot slots. + Dimensions are assigned to slots based on this order. + slot_priority: Order in which slots are filled during auto-assignment. + Default: x → color → facet_col → facet_row → animation_frame. 
Examples: ```python @@ -573,6 +575,9 @@ class Plotting: # Customize dimension priority for auto-assignment CONFIG.Plotting.dim_priority = ('time', 'scenario', 'variable', 'period', 'cluster') + + # Change slot fill order (e.g., prioritize facets over color) + CONFIG.Plotting.slot_priority = ('x', 'facet_col', 'facet_row', 'color', 'animation_frame') ``` """ @@ -584,6 +589,7 @@ class Plotting: default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_priority'] + slot_priority: tuple[str, ...] = _DEFAULTS['plotting']['slot_priority'] class Carriers: """Default carrier definitions for common energy types. @@ -686,6 +692,7 @@ def to_dict(cls) -> dict: 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, 'default_line_shape': cls.Plotting.default_line_shape, 'dim_priority': cls.Plotting.dim_priority, + 'slot_priority': cls.Plotting.slot_priority, }, } diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 746227e45..ee3e82399 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -59,7 +59,7 @@ def assign_slots( # Add any available dims not in priority list (fallback) priority_dims.extend(d for d in available if d not in priority_dims) - # Slot specification in fill order + # Slot specification slots = { 'x': x, 'color': color, @@ -67,8 +67,8 @@ def assign_slots( 'facet_row': facet_row, 'animation_frame': animation_frame, } - # Fixed fill order for 'auto' assignment - slot_order = ('x', 'color', 'facet_col', 'facet_row', 'animation_frame') + # Slot fill order from config + slot_order = CONFIG.Plotting.slot_priority results: dict[str, str | None] = {k: None for k in slot_order} used: set[str] = set() @@ -79,7 +79,7 @@ def assign_slots( used.add(value) results[slot_name] = value - # Second pass: resolve 'auto' slots in 
fixed fill order + # Second pass: resolve 'auto' slots in config-defined fill order dim_iter = iter(d for d in priority_dims if d not in used) for slot_name in slot_order: if slots[slot_name] == 'auto': From b1336f6de867d8a22d780f0a4d353bfbc774fbab Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:20:21 +0100 Subject: [PATCH 23/30] Add new assign_slots() method --- flixopt/dataset_plot_accessor.py | 162 +++++++++------------- flixopt/statistics_accessor.py | 224 +++++++++++++++++-------------- 2 files changed, 184 insertions(+), 202 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index ee3e82399..9afce67dd 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -22,11 +22,11 @@ def assign_slots( facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', + exclude_dims: set[str] | None = None, ) -> dict[str, str | None]: """Assign dimensions to plot slots using CONFIG.Plotting.dim_priority. - Slot fill order: x → color → facet_col → facet_row → animation_frame. - Dimensions are assigned in priority order from CONFIG.Plotting.dim_priority. + Dimensions are assigned in priority order to slots based on CONFIG.Plotting.slot_priority. Slot values: - 'auto': auto-assign from available dims using priority @@ -43,15 +43,17 @@ def assign_slots( facet_col: Column faceting dimension. facet_row: Row faceting dimension. animation_frame: Animation slider dimension. + exclude_dims: Dimensions to exclude from auto-assignment (e.g., already used for x elsewhere). Returns: Dict with keys 'x', 'color', 'facet_col', 'facet_row', 'animation_frame' and values being assigned dimension names (or None if slot skipped/unfilled). 
""" - # Get available dimensions with size > 1 - available = {d for d in ds.dims if ds.sizes[d] > 1} - # 'variable' is available when there are multiple data_vars - if len(ds.data_vars) > 1: + # Get available dimensions with size > 1, excluding specified dims + exclude = exclude_dims or set() + available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} + # 'variable' is available when there are multiple data_vars (and not excluded) + if len(ds.data_vars) > 1 and 'variable' not in exclude: available.add('variable') # Get priority-ordered list of available dims @@ -110,6 +112,34 @@ def assign_slots( return results +def _build_fig_kwargs( + slots: dict[str, str | None], + ds_sizes: dict[str, int], + px_kwargs: dict[str, Any], + facet_cols: int | None = None, +) -> dict[str, Any]: + """Build plotly express kwargs from slot assignments. + + Adds facet/animation args only if slots are assigned and not overridden in px_kwargs. + Handles facet_col_wrap based on dimension size. + """ + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + result: dict[str, Any] = {} + + # Add facet/animation kwargs from slots (skip if None or already in px_kwargs) + for slot in ('color', 'facet_col', 'facet_row', 'animation_frame'): + if slots.get(slot) and slot not in px_kwargs: + result[slot] = slots[slot] + + # Add facet_col_wrap if facet_col is set and dimension is large enough + if result.get('facet_col'): + dim_size = ds_sizes.get(result['facet_col'], facet_col_wrap + 1) + if facet_col_wrap < dim_size: + result['facet_col_wrap'] = facet_col_wrap + + return result + + def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: """Convert Dataset to long-form DataFrame for Plotly Express.""" if not ds.data_vars: @@ -195,42 +225,24 @@ def bar( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if 
df.empty: return go.Figure() - # Get color labels from the resolved color column color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'barmode': 'group', + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.bar(**{**fig_kwargs, **px_kwargs}) def stacked_bar( @@ -273,41 +285,23 @@ def stacked_bar( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color 
column color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) fig.update_traces(marker_line_width=0) @@ -355,42 +349,24 @@ def line( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color column 
color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.line(**{**fig_kwargs, **px_kwargs}) def area( @@ -432,42 +408,24 @@ def area( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color column color_labels = 
df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.area(**{**fig_kwargs, **px_kwargs}) def heatmap( diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index dc43287fc..da6a859f9 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -181,43 +181,8 @@ def _filter_by_carrier(ds: xr.Dataset, carrier: str | list[str] | None) -> xr.Da return ds[matching_vars] if matching_vars else 
xr.Dataset() -def _resolve_auto_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, -) -> tuple[str | None, str | None, str | None]: - """Resolve 'auto' facet/animation dimensions based on available data dimensions. - - Uses assign_slots with x=None and color=None to only resolve facet/animation slots. - - Args: - ds: Dataset to check for available dimensions. - facet_col: Dimension name, 'auto', or None. - facet_row: Dimension name, 'auto', or None. - animation_frame: Dimension name, 'auto', or None. - - Returns: - Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). - Each is either a valid dimension name or None. - """ - slots = assign_slots( - ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame - ) - return slots['facet_col'], slots['facet_row'], slots['animation_frame'] - - -def _resolve_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, -) -> tuple[str | None, str | None]: - """Resolve facet dimensions, returning None if not present in data. - - Legacy wrapper for _resolve_auto_facets for backward compatibility. 
- """ - resolved_col, resolved_row, _ = _resolve_auto_facets(ds, facet_col, facet_row, None) - return resolved_col, resolved_row +# Default dimensions to exclude from facet auto-assignment (typically x-axis dimensions) +_FACET_EXCLUDE_DIMS = {'time', 'duration', 'duration_pct'} def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -1355,8 +1320,14 @@ def balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Build color map from Element.color attributes if no colors specified @@ -1372,9 +1343,9 @@ def balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1466,8 +1437,14 @@ def carrier_balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Use cached component colors for flows @@ -1496,9 +1473,9 @@ def carrier_balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + 
animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1570,17 +1547,22 @@ def heatmap( # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - if has_multiple_vars: - actual_facet = 'variable' - # Resolve animation using auto logic, excluding 'variable' which is used for facet - _, _, actual_animation = _resolve_auto_facets(da.to_dataset(name='value'), None, None, animation_frame) - if actual_animation == 'variable': - actual_animation = None - else: - # Resolve facet and animation using auto logic - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) + # Get slot assignments (heatmap only uses facet_col and animation_frame) + slots = assign_slots( + da.to_dataset(name='value'), + x=None, + color=None, + facet_col='variable' if has_multiple_vars else facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, + ) + resolved_facet = slots['facet_col'] + resolved_anim = slots['animation_frame'] + + # Don't use 'variable' for animation if it's used for facet + if resolved_anim == 'variable' and has_multiple_vars: + resolved_anim = None # Determine heatmap dimensions based on data structure if is_clustered and (reshape == 'auto' or reshape is None): @@ -1588,21 +1570,27 @@ def heatmap( heatmap_dims = ['time', 'cluster'] elif reshape and reshape != 'auto' and 'time' in da.dims: # Non-clustered with explicit reshape: reshape time to (day, hour) etc. 
- # Extra dims will be handled via facet/animation or dropped da = _reshape_time_for_heatmap(da, reshape) heatmap_dims = ['timestep', 'timeframe'] elif reshape == 'auto' and 'time' in da.dims and not is_clustered: # Auto mode for non-clustered: use default ('D', 'h') reshape - # Extra dims will be handled via facet/animation or dropped da = _reshape_time_for_heatmap(da, ('D', 'h')) heatmap_dims = ['timestep', 'timeframe'] elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, use period/scenario for facet/animation - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame + # variable is now a heatmap dim, reassign facet + slots = assign_slots( + da.to_dataset(name='value'), + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS | {'variable'}, ) + resolved_facet = slots['facet_col'] + resolved_anim = slots['animation_frame'] else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] @@ -1614,12 +1602,12 @@ def heatmap( heatmap_dims = list(da.dims)[:1] # Keep only dims we need - keep_dims = set(heatmap_dims) | {d for d in [actual_facet, actual_animation] if d is not None} + keep_dims = set(heatmap_dims) | {d for d in [resolved_facet, resolved_anim] if d is not None} for dim in [d for d in da.dims if d not in keep_dims]: da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) # Transpose to expected order - dim_order = heatmap_dims + [d for d in [actual_facet, actual_animation] if d] + dim_order = heatmap_dims + [d for d in [resolved_facet, resolved_anim] if d] da = da.transpose(*dim_order) # Clear name for multiple variables (colorbar would show first var's name) @@ -1628,8 +1616,8 @@ def heatmap( fig = da.fxplot.heatmap( 
colors=colors, - facet_col=actual_facet, - animation_frame=actual_animation, + facet_col=resolved_facet, + animation_frame=resolved_anim, **plotly_kwargs, ) @@ -1710,8 +1698,14 @@ def flows( ds = ds[[lbl for lbl in matching_labels if lbl in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Get unit label from first data variable's attributes @@ -1723,9 +1717,9 @@ def flows( fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1771,8 +1765,14 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) df = _dataset_to_long_df(ds) @@ -1786,9 +1786,9 @@ def sizes( x='variable', y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], color_discrete_map=color_map, title='Investment Sizes', labels={'variable': 'Flow', 'value': 'Size'}, @@ -1886,8 +1886,14 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) result_ds = 
result_ds.assign_coords({duration_name: duration_coord}) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - result_ds, facet_col, facet_row, animation_frame + slots = assign_slots( + result_ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Get unit label from first data variable's attributes @@ -1899,9 +1905,9 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -2031,8 +2037,14 @@ def effects( raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") # Resolve facets - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - combined.to_dataset(name='value'), facet_col, facet_row, animation_frame + slots = assign_slots( + combined.to_dataset(name='value'), + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Convert to DataFrame for plotly express @@ -2060,9 +2072,9 @@ def effects( y='value', color=color_col, color_discrete_map=color_map, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], title=title, **plotly_kwargs, ) @@ -2111,16 +2123,22 @@ def charge_states( ds = ds[[s for s in storages if s in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + 
facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) fig = ds.fxplot.line( colors=colors, title='Storage Charge States', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) fig.update_yaxes(title_text='Charge State') @@ -2204,8 +2222,14 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Build color map @@ -2227,9 +2251,9 @@ def storage( x='time', y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], color_discrete_map=color_map, title=f'{storage} Operation ({unit})', **plotly_kwargs, @@ -2244,9 +2268,9 @@ def storage( charge_df, x='time', y='value', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], ) # Get the primary y-axes from the bar figure to create matching secondary axes From 28bb631a9460645c737f6fe57a54a9ff63cff2af Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:33:14 +0100 Subject: [PATCH 24/30] Add new assign_slots() method --- flixopt/dataset_plot_accessor.py | 38 ++++++- flixopt/statistics_accessor.py | 171 ++++++++----------------------- 2 files changed, 75 insertions(+), 134 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py 
b/flixopt/dataset_plot_accessor.py index 9afce67dd..6451037f5 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -201,6 +201,7 @@ def bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a grouped bar chart from the dataset. @@ -217,13 +218,20 @@ def bar( facet_row: Dimension for row facets. 'auto' uses CONFIG priority. animation_frame: Dimension for animation slider. facet_cols: Number of columns in facet grid wrap. + exclude_dims: Dimensions to exclude from auto-assignment. **px_kwargs: Additional arguments passed to plotly.express.bar. Returns: Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -258,6 +266,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart from the dataset. @@ -283,7 +292,13 @@ def stacked_bar( Plotly Figure. 
""" slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -321,6 +336,7 @@ def line( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a line chart from the dataset. @@ -347,7 +363,13 @@ def line( Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -383,6 +405,7 @@ def area( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked area chart from the dataset. @@ -406,7 +429,13 @@ def area( Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -777,6 +806,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart. 
See DatasetPlotAccessor.stacked_bar for details.""" diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index da6a859f9..26a4bf55c 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,7 +31,6 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG -from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -181,10 +180,6 @@ def _filter_by_carrier(ds: xr.Dataset, carrier: str | list[str] | None) -> xr.Da return ds[matching_vars] if matching_vars else xr.Dataset() -# Default dimensions to exclude from facet auto-assignment (typically x-axis dimensions) -_FACET_EXCLUDE_DIMS = {'time', 'duration', 'duration_pct'} - - def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: """Convert xarray Dataset to long-form DataFrame for plotly express.""" if not ds.data_vars: @@ -1320,15 +1315,6 @@ def balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Build color map from Element.color attributes if no colors specified if colors is None: @@ -1343,9 +1329,9 @@ def balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1437,15 +1423,6 @@ def carrier_balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Use cached component colors for flows if 
colors is None: @@ -1473,9 +1450,9 @@ def carrier_balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1547,18 +1524,18 @@ def heatmap( # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - # Get slot assignments (heatmap only uses facet_col and animation_frame) - slots = assign_slots( - da.to_dataset(name='value'), - x=None, - color=None, - facet_col='variable' if has_multiple_vars else facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - resolved_facet = slots['facet_col'] - resolved_anim = slots['animation_frame'] + # For heatmap, facet defaults to 'variable' if multiple vars + # Resolve 'auto' to None for heatmap (no auto-faceting by time etc.) 
+ if facet_col == 'auto': + resolved_facet = 'variable' if has_multiple_vars else None + else: + resolved_facet = facet_col + + # Resolve animation_frame - 'auto' means None for heatmap (no auto-animation) + if animation_frame == 'auto': + resolved_anim = None + else: + resolved_anim = animation_frame # Don't use 'variable' for animation if it's used for facet if resolved_anim == 'variable' and has_multiple_vars: @@ -1579,18 +1556,8 @@ def heatmap( elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, reassign facet - slots = assign_slots( - da.to_dataset(name='value'), - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS | {'variable'}, - ) - resolved_facet = slots['facet_col'] - resolved_anim = slots['animation_frame'] + # variable is now a heatmap dim, use user's facet choice + resolved_facet = facet_col else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] @@ -1698,15 +1665,6 @@ def flows( ds = ds[[lbl for lbl in matching_labels if lbl in ds]] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Get unit label from first data variable's attributes unit_label = '' @@ -1717,9 +1675,9 @@ def flows( fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1765,16 +1723,6 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - slots = assign_slots( - ds, 
- x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - df = _dataset_to_long_df(ds) if df.empty: fig = go.Figure() @@ -1786,9 +1734,9 @@ def sizes( x='variable', y='value', color='variable', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, color_discrete_map=color_map, title='Investment Sizes', labels={'variable': 'Flow', 'value': 'Size'}, @@ -1886,16 +1834,6 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) result_ds = result_ds.assign_coords({duration_name: duration_coord}) - slots = assign_slots( - result_ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - # Get unit label from first data variable's attributes unit_label = '' if ds.data_vars: @@ -1905,9 +1843,9 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -2037,15 +1975,6 @@ def effects( raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") # Resolve facets - slots = assign_slots( - combined.to_dataset(name='value'), - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Convert to DataFrame for plotly express df = combined.to_dataframe(name='value').reset_index() @@ -2072,9 +2001,9 @@ def effects( y='value', color=color_col, 
color_discrete_map=color_map, - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, title=title, **plotly_kwargs, ) @@ -2123,22 +2052,13 @@ def charge_states( ds = ds[[s for s in storages if s in ds]] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) fig = ds.fxplot.line( colors=colors, title='Storage Charge States', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) fig.update_yaxes(title_text='Charge State') @@ -2222,15 +2142,6 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Build color map flow_labels = [lbl for lbl in ds.data_vars if lbl != 'charge_state'] @@ -2251,9 +2162,9 @@ def storage( x='time', y='value', color='variable', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, color_discrete_map=color_map, title=f'{storage} Operation ({unit})', **plotly_kwargs, @@ -2268,9 +2179,9 @@ def storage( charge_df, x='time', y='value', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, ) # Get the primary y-axes from the bar figure to create matching secondary axes From 4f8407a6353adc32195789a8393f65aefd1802a0 Mon Sep 17 00:00:00 2001 From: FBumann 
<117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:54:45 +0100 Subject: [PATCH 25/30] Fix heatmap and convert all to use fxplot --- flixopt/dataset_plot_accessor.py | 24 +++++- flixopt/statistics_accessor.py | 131 +++++++++++++------------------ 2 files changed, 73 insertions(+), 82 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 6451037f5..a310c94b8 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -504,9 +504,17 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Heatmap uses imshow - x/y come from array axes, color is continuous + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Exclude these from slot assignment + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() slots = assign_slots( - self._ds, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + self._ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, ) imshow_args: dict[str, Any] = { @@ -914,10 +922,18 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Heatmap uses imshow - x/y come from array axes, color is continuous + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Exclude these from slot assignment + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() ds_for_resolution = da.to_dataset(name='_temp') slots = assign_slots( - ds_for_resolution, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + ds_for_resolution, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, ) imshow_args: 
dict[str, Any] = { diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 26a4bf55c..1df6704ea 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,6 +31,7 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG +from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -1520,28 +1521,9 @@ def heatmap( # Check if data is clustered (has cluster dimension with size > 1) is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - - # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - # For heatmap, facet defaults to 'variable' if multiple vars - # Resolve 'auto' to None for heatmap (no auto-faceting by time etc.) - if facet_col == 'auto': - resolved_facet = 'variable' if has_multiple_vars else None - else: - resolved_facet = facet_col - - # Resolve animation_frame - 'auto' means None for heatmap (no auto-animation) - if animation_frame == 'auto': - resolved_anim = None - else: - resolved_anim = animation_frame - - # Don't use 'variable' for animation if it's used for facet - if resolved_anim == 'variable' and has_multiple_vars: - resolved_anim = None - - # Determine heatmap dimensions based on data structure + # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): # Clustered data: use (time, cluster) as natural 2D heatmap axes heatmap_dims = ['time', 'cluster'] @@ -1556,26 +1538,34 @@ def heatmap( elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, use user's facet choice - resolved_facet = facet_col else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] - if len(available_dims) >= 2: - heatmap_dims = 
available_dims[:2] - elif 'time' in da.dims: - heatmap_dims = ['time'] - else: - heatmap_dims = list(da.dims)[:1] + heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] + + # Resolve facet/animation using assign_slots, excluding heatmap dims + ds_temp = da.to_dataset(name='_temp') + slots = assign_slots( + ds_temp, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=set(heatmap_dims), + ) - # Keep only dims we need - keep_dims = set(heatmap_dims) | {d for d in [resolved_facet, resolved_anim] if d is not None} + # Keep only dims we need (heatmap axes + facet/animation) + keep_dims = set(heatmap_dims) | {d for d in [slots['facet_col'], slots['animation_frame']] if d} for dim in [d for d in da.dims if d not in keep_dims]: da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) - # Transpose to expected order - dim_order = heatmap_dims + [d for d in [resolved_facet, resolved_anim] if d] - da = da.transpose(*dim_order) + # Transpose to expected order (heatmap dims first) + dim_order = [d for d in heatmap_dims if d in da.dims] + [ + d for d in [slots['facet_col'], slots['animation_frame']] if d and d in da.dims + ] + if len(dim_order) == len(da.dims): + da = da.transpose(*dim_order) # Clear name for multiple variables (colorbar would show first var's name) if has_multiple_vars: @@ -1583,8 +1573,8 @@ def heatmap( fig = da.fxplot.heatmap( colors=colors, - facet_col=resolved_facet, - animation_frame=resolved_anim, + facet_col=slots['facet_col'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1723,23 +1713,18 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - df = _dataset_to_long_df(ds) - if df.empty: + if not ds.data_vars: fig = go.Figure() else: - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables) - fig = px.bar( - df, + 
fig = ds.fxplot.bar( x='variable', - y='value', color='variable', + colors=colors, + title='Investment Sizes', + ylabel='Size', facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, - color_discrete_map=color_map, - title='Investment Sizes', - labels={'variable': 'Flow', 'value': 'Size'}, **plotly_kwargs, ) @@ -1974,11 +1959,14 @@ def effects( else: raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") - # Resolve facets - # Convert to DataFrame for plotly express df = combined.to_dataframe(name='value').reset_index() + # Resolve facet/animation: 'auto' means None for DataFrames (no dimension priority) + resolved_facet_col = None if facet_col == 'auto' else facet_col + resolved_facet_row = None if facet_row == 'auto' else facet_row + resolved_animation = None if animation_frame == 'auto' else animation_frame + # Build color map if color_col and color_col in df.columns: color_items = df[color_col].unique().tolist() @@ -2001,9 +1989,9 @@ def effects( y='value', color=color_col, color_discrete_map=color_map, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, + facet_col=resolved_facet_col, + facet_row=resolved_facet_row, + animation_frame=resolved_animation, title=title, **plotly_kwargs, ) @@ -2143,57 +2131,45 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - # Build color map + # Separate flow data from charge_state flow_labels = [lbl for lbl in ds.data_vars if lbl != 'charge_state'] + flow_ds = ds[flow_labels] + charge_da = ds['charge_state'] + + # Build color map for flows if colors is None: colors = self._get_color_map_for_balance(storage, flow_labels) - color_map = process_colors(colors, flow_labels) - color_map['charge_state'] = 'black' - # Convert to long-form DataFrame - df = _dataset_to_long_df(ds) - - # Create figure with facets using px.bar for flows - flow_df = df[df['variable'] != 'charge_state'] - charge_df = df[df['variable'] == 
'charge_state'] - - fig = px.bar( - flow_df, + # Create stacked bar chart for flows using fxplot + fig = flow_ds.fxplot.stacked_bar( x='time', - y='value', color='variable', + colors=colors, + title=f'{storage} Operation ({unit})', facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, - color_discrete_map=color_map, - title=f'{storage} Operation ({unit})', **plotly_kwargs, ) - fig.update_layout(bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) # Add charge state as line on secondary y-axis - if not charge_df.empty: - # Create line figure with same facets to get matching trace structure - line_fig = px.line( - charge_df, + if charge_da.size > 0: + # Create line figure with same facets + line_fig = charge_da.fxplot.line( x='time', - y='value', + color=None, # Single line, no color grouping facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, ) # Get the primary y-axes from the bar figure to create matching secondary axes - # px creates axes named: yaxis, yaxis2, yaxis3, etc. primary_yaxes = [key for key in fig.layout if key.startswith('yaxis')] # For each primary y-axis, create a secondary y-axis for i, primary_key in enumerate(sorted(primary_yaxes, key=lambda x: int(x[5:]) if x[5:] else 0)): - # Determine secondary axis name (y -> y2, y2 -> y3 pattern won't work) - # Instead use a consistent offset: yaxis -> yaxis10, yaxis2 -> yaxis11, etc. 
primary_num = primary_key[5:] if primary_key[5:] else '1' - secondary_num = int(primary_num) + 100 # Use high offset to avoid conflicts + secondary_num = int(primary_num) + 100 secondary_key = f'yaxis{secondary_num}' secondary_anchor = f'x{primary_num}' if primary_num != '1' else 'x' @@ -2207,14 +2183,13 @@ def storage( # Add line traces with correct axis assignments for i, trace in enumerate(line_fig.data): - # Map trace index to secondary y-axis primary_num = i + 1 if i > 0 else 1 secondary_yaxis = f'y{primary_num + 100}' trace.name = 'charge_state' trace.line = dict(color=charge_state_color, width=2) trace.yaxis = secondary_yaxis - trace.showlegend = i == 0 # Only show legend for first trace + trace.showlegend = i == 0 trace.legendgroup = 'charge_state' fig.add_trace(trace) From a4d46811c889b652f78b1ae1f1a95a0af96d1777 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:08:01 +0100 Subject: [PATCH 26/30] Fix heatmap --- flixopt/dataset_plot_accessor.py | 88 ++++++++++++++++++++------------ flixopt/statistics_accessor.py | 48 ++++++++--------- 2 files changed, 75 insertions(+), 61 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index a310c94b8..47cb0564a 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -505,17 +505,24 @@ def heatmap( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap - # Exclude these from slot assignment - heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() - slots = assign_slots( - self._ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=heatmap_axes, - ) + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else 
set() + slots = assign_slots( + self._ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -523,13 +530,17 @@ def heatmap( 'title': title or variable, } - if slots['facet_col'] and slots['facet_col'] in da.dims: - imshow_args['facet_col'] = slots['facet_col'] - if facet_col_wrap < da.sizes[slots['facet_col']]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if slots['animation_frame'] and slots['animation_frame'] in da.dims: - imshow_args['animation_frame'] = slots['animation_frame'] + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) @@ -923,18 +934,25 @@ def heatmap( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap - # Exclude these from slot assignment - heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() - ds_for_resolution = da.to_dataset(name='_temp') - slots = assign_slots( - ds_for_resolution, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=heatmap_axes, - ) + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if 
len(da.dims) >= 2 else set() + ds_for_resolution = da.to_dataset(name='_temp') + slots = assign_slots( + ds_for_resolution, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -942,12 +960,16 @@ def heatmap( 'title': title or (da.name if da.name else ''), } - if slots['facet_col'] and slots['facet_col'] in da.dims: - imshow_args['facet_col'] = slots['facet_col'] - if facet_col_wrap < da.sizes[slots['facet_col']]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if slots['animation_frame'] and slots['animation_frame'] in da.dims: - imshow_args['animation_frame'] = slots['animation_frame'] + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 1df6704ea..2ac9060ac 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,7 +31,6 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG -from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -1472,7 +1471,7 @@ def heatmap( reshape: tuple[str, str] | Literal['auto'] | None = 'auto', colors: str | list[str] | None = None, facet_col: str | Literal['auto'] | 
None = 'auto', - animation_frame: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -1523,6 +1522,11 @@ def heatmap( is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 + # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion + extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + # Max dims for heatmap: 2 axes + facet_col + animation_frame = 4 + can_reshape = len(extra_dims) <= 2 # Leave room for facet and animation + # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): # Clustered data: use (time, cluster) as natural 2D heatmap axes @@ -1531,8 +1535,8 @@ def heatmap( # Non-clustered with explicit reshape: reshape time to (day, hour) etc. da = _reshape_time_for_heatmap(da, reshape) heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered: - # Auto mode for non-clustered: use default ('D', 'h') reshape + elif reshape == 'auto' and 'time' in da.dims and not is_clustered and can_reshape: + # Auto mode for non-clustered: use default ('D', 'h') reshape only if not too many dims da = _reshape_time_for_heatmap(da, ('D', 'h')) heatmap_dims = ['timestep', 'timeframe'] elif has_multiple_vars: @@ -1543,38 +1547,26 @@ def heatmap( available_dims = [d for d in da.dims if da.sizes[d] > 1] heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] - # Resolve facet/animation using assign_slots, excluding heatmap dims - ds_temp = da.to_dataset(name='_temp') - slots = assign_slots( - ds_temp, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=set(heatmap_dims), - ) - - # Keep only dims we need (heatmap axes + 
facet/animation) - keep_dims = set(heatmap_dims) | {d for d in [slots['facet_col'], slots['animation_frame']] if d} - for dim in [d for d in da.dims if d not in keep_dims]: - da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) + # Transpose so heatmap dims come first (px.imshow uses first 2 dims as y/x axes) + other_dims = [d for d in da.dims if d not in heatmap_dims] + dim_order = [d for d in heatmap_dims if d in da.dims] + other_dims + # Always transpose to ensure correct dim order (even if seemingly equal, xarray dim order matters) + da = da.transpose(*dim_order) - # Transpose to expected order (heatmap dims first) - dim_order = [d for d in heatmap_dims if d in da.dims] + [ - d for d in [slots['facet_col'], slots['animation_frame']] if d and d in da.dims - ] - if len(dim_order) == len(da.dims): - da = da.transpose(*dim_order) + # Squeeze single-element dims (except heatmap axes) to avoid 3D shape errors + for dim in list(da.dims): + if dim not in heatmap_dims and da.sizes[dim] == 1: + da = da.squeeze(dim, drop=True) # Clear name for multiple variables (colorbar would show first var's name) if has_multiple_vars: da = da.rename('') + # Let fxplot handle slot assignment for facet/animation fig = da.fxplot.heatmap( colors=colors, - facet_col=slots['facet_col'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + animation_frame=animation_frame, **plotly_kwargs, ) From ae5655dd4ceea31f8cc4b147aaf71b158d9846d5 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:14:06 +0100 Subject: [PATCH 27/30] Fix heatmap --- flixopt/statistics_accessor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 2ac9060ac..bbd61980f 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -1523,9 +1523,11 @@ def heatmap( has_multiple_vars = 'variable' in da.dims and 
da.sizes['variable'] > 1 # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion + # Reshape adds 1 dim (time -> timestep + timeframe), so check available slots extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - # Max dims for heatmap: 2 axes + facet_col + animation_frame = 4 - can_reshape = len(extra_dims) <= 2 # Leave room for facet and animation + # Count available slots: 'auto' means available, None/explicit means not available + available_slots = (1 if facet_col == 'auto' else 0) + (1 if animation_frame == 'auto' else 0) + can_reshape = len(extra_dims) <= available_slots # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): From 56b183810d7d16e7132f478b890e2343a81c8aeb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:16:49 +0100 Subject: [PATCH 28/30] Fix heatmap --- flixopt/statistics_accessor.py | 122 +++++++++++++++------------------ 1 file changed, 57 insertions(+), 65 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index bbd61980f..09d75f145 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -127,6 +127,55 @@ def _reshape_time_for_heatmap( # --- Helper functions --- +def _prepare_for_heatmap( + da: xr.DataArray, + reshape: tuple[str, str] | Literal['auto'] | None, + facet_col: str | Literal['auto'] | None, + animation_frame: str | Literal['auto'] | None, +) -> xr.DataArray: + """Prepare DataArray for heatmap: determine axes, reshape if needed, transpose/squeeze.""" + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 + has_time = 'time' in da.dims + has_multi_vars = da.sizes.get('variable', 1) > 1 + + # Determine heatmap axes and apply reshape if needed + if is_clustered and reshape in ('auto', None): + heatmap_dims = ['time', 'cluster'] + elif reshape and reshape != 
'auto' and has_time: + da = _reshape_time_for_heatmap(da, reshape) + heatmap_dims = ['timestep', 'timeframe'] + elif reshape == 'auto' and has_time and not is_clustered: + # Check if we have room for extra dims after reshaping (adds 1 dim: time -> timestep + timeframe) + extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + available_slots = (facet_col == 'auto') + (animation_frame == 'auto') + if len(extra_dims) <= available_slots: + da = _reshape_time_for_heatmap(da, ('D', 'h')) + heatmap_dims = ['timestep', 'timeframe'] + elif has_multi_vars: + heatmap_dims = ['variable', 'time'] + else: + heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] + elif has_multi_vars: + heatmap_dims = ['variable', 'time'] + else: + heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] + + # Transpose: heatmap dims first, then others + other_dims = [d for d in da.dims if d not in heatmap_dims] + da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other_dims) + + # Squeeze single-element dims (except heatmap axes) + for dim in list(da.dims): + if dim not in heatmap_dims and da.sizes[dim] == 1: + da = da.squeeze(dim, drop=True) + + # Clear name for multiple variables (colorbar would show first var's name) + if has_multi_vars: + da = da.rename('') + + return da + + def _filter_by_pattern( names: list[str], include: FilterType | None, @@ -1503,82 +1552,25 @@ def heatmap( PlotResult with processed data and figure. 
""" solution = self._stats._require_solution() - if isinstance(variables, str): variables = [variables] - # Resolve flow labels to variable names - resolved_variables = self._resolve_variable_names(variables, solution) + # Resolve, select, and stack into single DataArray + resolved = self._resolve_variable_names(variables, solution) + ds = _apply_selection(solution[resolved], select) + da = xr.concat([ds[v] for v in ds.data_vars], dim=pd.Index(list(ds.data_vars), name='variable')) - ds = solution[resolved_variables] - ds = _apply_selection(ds, select) + # Prepare for heatmap (reshape, transpose, squeeze) + da = _prepare_for_heatmap(da, reshape, facet_col, animation_frame) - # Stack variables into single DataArray - variable_names = list(ds.data_vars) - dataarrays = [ds[var] for var in variable_names] - da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) - - # Check if data is clustered (has cluster dimension with size > 1) - is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - - # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion - # Reshape adds 1 dim (time -> timestep + timeframe), so check available slots - extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - # Count available slots: 'auto' means available, None/explicit means not available - available_slots = (1 if facet_col == 'auto' else 0) + (1 if animation_frame == 'auto' else 0) - can_reshape = len(extra_dims) <= available_slots - - # Apply time reshape if needed (creates timestep/timeframe dims) - if is_clustered and (reshape == 'auto' or reshape is None): - # Clustered data: use (time, cluster) as natural 2D heatmap axes - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and 'time' in da.dims: - # Non-clustered with explicit reshape: reshape time to (day, hour) etc. 
- da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered and can_reshape: - # Auto mode for non-clustered: use default ('D', 'h') reshape only if not too many dims - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multiple_vars: - # Can't reshape but have multiple vars: use variable + time as heatmap axes - heatmap_dims = ['variable', 'time'] - else: - # Fallback: use first two available dimensions - available_dims = [d for d in da.dims if da.sizes[d] > 1] - heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] - - # Transpose so heatmap dims come first (px.imshow uses first 2 dims as y/x axes) - other_dims = [d for d in da.dims if d not in heatmap_dims] - dim_order = [d for d in heatmap_dims if d in da.dims] + other_dims - # Always transpose to ensure correct dim order (even if seemingly equal, xarray dim order matters) - da = da.transpose(*dim_order) - - # Squeeze single-element dims (except heatmap axes) to avoid 3D shape errors - for dim in list(da.dims): - if dim not in heatmap_dims and da.sizes[dim] == 1: - da = da.squeeze(dim, drop=True) - - # Clear name for multiple variables (colorbar would show first var's name) - if has_multiple_vars: - da = da.rename('') - - # Let fxplot handle slot assignment for facet/animation - fig = da.fxplot.heatmap( - colors=colors, - facet_col=facet_col, - animation_frame=animation_frame, - **plotly_kwargs, - ) + fig = da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=animation_frame, **plotly_kwargs) if show is None: show = CONFIG.Plotting.default_show if show: fig.show() - reshaped_ds = da.to_dataset(name='value') if isinstance(da, xr.DataArray) else da - return PlotResult(data=reshaped_ds, figure=fig) + return PlotResult(data=da.to_dataset(name='value'), figure=fig) def flows( self, From 
56719e8d278b51d81261df82f5794d9e6b28cd05 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:25:21 +0100 Subject: [PATCH 29/30] Fix heatmap --- flixopt/statistics_accessor.py | 63 ++++++++++++++++------------------ 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 09d75f145..e3581e4e3 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -134,46 +134,41 @@ def _prepare_for_heatmap( animation_frame: str | Literal['auto'] | None, ) -> xr.DataArray: """Prepare DataArray for heatmap: determine axes, reshape if needed, transpose/squeeze.""" + + def finalize(da: xr.DataArray, heatmap_dims: list[str]) -> xr.DataArray: + """Transpose, squeeze, and clear name if needed.""" + other = [d for d in da.dims if d not in heatmap_dims] + da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other) + for dim in [d for d in da.dims if d not in heatmap_dims and da.sizes[d] == 1]: + da = da.squeeze(dim, drop=True) + return da.rename('') if da.sizes.get('variable', 1) > 1 else da + + def fallback_dims() -> list[str]: + """Default dims: (variable, time) if multi-var, else first 2 dims with size > 1.""" + if da.sizes.get('variable', 1) > 1: + return ['variable', 'time'] + dims = [d for d in da.dims if da.sizes[d] > 1][:2] + return dims if len(dims) >= 2 else list(da.dims)[:2] + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 has_time = 'time' in da.dims - has_multi_vars = da.sizes.get('variable', 1) > 1 - # Determine heatmap axes and apply reshape if needed + # Clustered: use (time, cluster) as natural 2D if is_clustered and reshape in ('auto', None): - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and has_time: - da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and has_time and not is_clustered: - # Check if 
we have room for extra dims after reshaping (adds 1 dim: time -> timestep + timeframe) - extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - available_slots = (facet_col == 'auto') + (animation_frame == 'auto') - if len(extra_dims) <= available_slots: - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multi_vars: - heatmap_dims = ['variable', 'time'] - else: - heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] - elif has_multi_vars: - heatmap_dims = ['variable', 'time'] - else: - heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] - - # Transpose: heatmap dims first, then others - other_dims = [d for d in da.dims if d not in heatmap_dims] - da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other_dims) - - # Squeeze single-element dims (except heatmap axes) - for dim in list(da.dims): - if dim not in heatmap_dims and da.sizes[dim] == 1: - da = da.squeeze(dim, drop=True) + return finalize(da, ['time', 'cluster']) + + # Explicit reshape: always apply + if reshape and reshape != 'auto' and has_time: + return finalize(_reshape_time_for_heatmap(da, reshape), ['timestep', 'timeframe']) - # Clear name for multiple variables (colorbar would show first var's name) - if has_multi_vars: - da = da.rename('') + # Auto reshape (non-clustered): apply only if extra dims fit in available slots + if reshape == 'auto' and has_time: + extra = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + slots = (facet_col == 'auto') + (animation_frame == 'auto') + if len(extra) <= slots: + return finalize(_reshape_time_for_heatmap(da, ('D', 'h')), ['timestep', 'timeframe']) - return da + return finalize(da, fallback_dims()) def _filter_by_pattern( From c6da15f800eb1f2c459b593c1b73bfa10452595b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:57:51 +0100 
Subject: [PATCH 30/30] Squeeze singleton dims in heatmap() --- flixopt/clustering/base.py | 5 +++++ flixopt/dataset_plot_accessor.py | 26 ++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index ab9590aae..0f154484b 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -827,6 +827,11 @@ def heatmap( heatmap_da = heatmap_da.assign_coords(y=['Cluster']) heatmap_da.name = 'cluster_assignment' + # Reorder dims so 'time' and 'y' are first (heatmap x/y axes) + # Other dims (period, scenario) will be used for faceting/animation + target_order = ['time', 'y'] + [d for d in heatmap_da.dims if d not in ('time', 'y')] + heatmap_da = heatmap_da.transpose(*target_order) + # Use fxplot.heatmap for smart defaults fig = heatmap_da.fxplot.heatmap( colors=colors, diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 47cb0564a..6c833e652 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -525,7 +525,6 @@ def heatmap( resolved_animation = animation_frame imshow_args: dict[str, Any] = { - 'img': da, 'color_continuous_scale': colors, 'title': title or variable, } @@ -538,6 +537,18 @@ def heatmap( if resolved_animation and resolved_animation in da.dims: imshow_args['animation_frame'] = resolved_animation + # Squeeze singleton dimensions not used for faceting/animation + # px.imshow can't handle extra singleton dims in multi-dimensional data + dims_to_preserve = set(list(da.dims)[:2]) # First 2 dims are heatmap x/y axes + if resolved_facet and resolved_facet in da.dims: + dims_to_preserve.add(resolved_facet) + if resolved_animation and resolved_animation in da.dims: + dims_to_preserve.add(resolved_animation) + for dim in list(da.dims): + if dim not in dims_to_preserve and da.sizes[dim] == 1: + da = da.squeeze(dim) + imshow_args['img'] = da + # Use binary_string=False to handle non-numeric coords (e.g.,
string labels) if 'binary_string' not in imshow_kwargs: imshow_args['binary_string'] = False @@ -955,7 +966,6 @@ def heatmap( resolved_animation = animation_frame imshow_args: dict[str, Any] = { - 'img': da, 'color_continuous_scale': colors, 'title': title or (da.name if da.name else ''), } @@ -968,6 +978,18 @@ def heatmap( if resolved_animation and resolved_animation in da.dims: imshow_args['animation_frame'] = resolved_animation + # Squeeze singleton dimensions not used for faceting/animation + # px.imshow can't handle extra singleton dims in multi-dimensional data + dims_to_preserve = set(list(da.dims)[:2]) # First 2 dims are heatmap x/y axes + if resolved_facet and resolved_facet in da.dims: + dims_to_preserve.add(resolved_facet) + if resolved_animation and resolved_animation in da.dims: + dims_to_preserve.add(resolved_animation) + for dim in list(da.dims): + if dim not in dims_to_preserve and da.sizes[dim] == 1: + da = da.squeeze(dim) + imshow_args['img'] = da + # Use binary_string=False to handle non-numeric coords (e.g., string labels) if 'binary_string' not in imshow_kwargs: imshow_args['binary_string'] = False