Skip to content

Commit e6411fa

Browse files
committed
Add colormap_from_single_color
[skip ci]
1 parent 4cfa975 commit e6411fa

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

src/simdec/visualization.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
from typing import Literal
33

4+
import colorsys
45
import matplotlib as mpl
56
import matplotlib.pyplot as plt
67
import numpy as np
@@ -11,7 +12,7 @@
1112
__all__ = ["visualization", "tableau", "palette"]
1213

1314

14-
sequential_palettes = [
15+
SEQUENTIAL_PALETTES = [
1516
"Oranges",
1617
"Purples",
1718
"Reds",
@@ -33,6 +34,37 @@
3334
]
3435

3536

37+
def colormap_from_single_color(rgba_color, *, factor=0.5):
38+
"""Create a linear colormap using a single color."""
39+
# discard alpha channel
40+
if len(rgba_color) == 4:
41+
*rgb_color, alpha = rgba_color
42+
else:
43+
alpha = 1
44+
rgb_color = rgba_color
45+
46+
# lighten and darken from factor around single color
47+
hls_color = colorsys.rgb_to_hls(*rgb_color)
48+
49+
lightness = hls_color[1]
50+
lightened_hls_color = (hls_color[0], lightness * (1 + factor), hls_color[2])
51+
lightened_rgb_color = list(colorsys.hls_to_rgb(*lightened_hls_color))
52+
53+
darkened_hls_color = (hls_color[0], lightness * (1 - factor), hls_color[2])
54+
darkened_rgb_color = list(colorsys.hls_to_rgb(*darkened_hls_color))
55+
56+
lightened_rgba_color = lightened_rgb_color + [alpha]
57+
darkened_rgba_color = darkened_rgb_color + [alpha]
58+
59+
# convert to CMAP
60+
cmap = mpl.colors.LinearSegmentedColormap.from_list(
61+
"CustomSingleColor",
62+
[lightened_rgba_color, rgba_color, darkened_rgba_color],
63+
N=3,
64+
)
65+
return cmap
66+
67+
3668
def palette(states: list[int]) -> list[list[float]]:
3769
"""Colour palette.
3870
@@ -54,7 +86,7 @@ def palette(states: list[int]) -> list[list[float]]:
5486
# many levels
5587
n_shades = int(np.prod(states[1:]))
5688
for i in range(states[0]):
57-
palette_ = sequential_palettes[i]
89+
palette_ = SEQUENTIAL_PALETTES[i]
5890
cmap = mpl.colormaps[palette_].resampled(n_shades + 1)
5991
colors.append(cmap(range(1, n_shades + 1)))
6092

0 commit comments

Comments
 (0)