My question stems from the solution provided here.
In my code below, I would like to automatically take the list of variable names, x, and assign one colour to each variable from a colour map (e.g. using get_cmap). I would also like each variable to appear only once in the legend. In this example, the variables B and H are duplicated, and I have manually assigned limegreen and black to them, respectively.
```python
import matplotlib.pyplot as plt

x = ["A","B","B","C","D","E","H","F","G","H"]

y = [-25, -10, -5, 5, 10, 30, 35, 40, 50, 60]

w = [30, 20, 30, 25, 40, 20, 40, 40, 40, 30]

colors = ["yellow","limegreen","limegreen","green","blue","red","black","brown","grey","black"]

plt.figure(figsize=(20,10))

xticks=[]
for n, c in enumerate(w):
    xticks.append(sum(w[:n]) + w[n]/2)

w_new = [i/max(w) for i in w]
a = plt.bar(xticks, height = y, width = w, color = colors, alpha = 0.8)
_ = plt.xticks(xticks, w)
plt.legend(a.patches, x)
```
Answer
Here, I am using dict and zip to keep a single entry for each value of x; there are easier ways if you import additional libraries such as numpy or pandas. What we are doing is custom-building the matplotlib legend, based on this article:
```python
a = plt.bar(xticks, height = y, width = w, color = colors, alpha = 0.8)
_ = plt.xticks(xticks, w)
# one (label, patch) pair per unique label; see "Details" below
x, patches = zip(*dict(zip(x, a.patches)).items())
plt.legend(patches, x)
```
Output: (plot image not preserved)
Details:
- Line up x with a.patches using zip.
- Use each x as a dictionary key mapped to its patch. Dictionary keys are unique, so when a label repeats, the later patch overwrites the earlier one and only one patch per label is kept.
- Unpack the dictionary's (label, patch) items into two tuples.
- Pass these as the handles and labels to plt.legend.
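The dictionary step can be checked without matplotlib. In this minimal sketch, the strings "bar0" … "bar9" are hypothetical stand-ins for a.patches; it shows that a repeated label keeps its first position in the dictionary but ends up holding the last patch (dict insertion order is guaranteed from Python 3.7).

```python
labels = ["A", "B", "B", "C", "D", "E", "H", "F", "G", "H"]
patches = [f"bar{i}" for i in range(len(labels))]  # hypothetical placeholders for a.patches

# dict(zip(...)) keeps one value per key: a repeated key overwrites the earlier
# value, while the key itself stays where it was first inserted.
mapping = dict(zip(labels, patches))
legend_labels, legend_handles = zip(*mapping.items())

print(legend_labels)   # ('A', 'B', 'C', 'D', 'E', 'H', 'F', 'G')
print(legend_handles)  # ('bar0', 'bar2', 'bar3', 'bar4', 'bar5', 'bar9', 'bar7', 'bar8')
```

So for B the legend uses the second B bar, and for H the last H bar; with consistent colours per label this makes no visible difference.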
Or you can use:
```python
set_x = sorted(set(x))                      # unique labels, alphabetical
xind = [x.index(i) for i in set_x]          # index of the first bar with each label
set_patches = [a.patches[i] for i in xind]
plt.legend(set_patches, set_x)
```
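Unlike the dictionary version, this variant sorts the labels and keeps the first bar for each one. A minimal sketch with hypothetical string placeholders instead of real patches shows the indices it picks; the last line is an optional alternative that keeps the plotting order instead of sorting.

```python
labels = ["A", "B", "B", "C", "D", "E", "H", "F", "G", "H"]
patches = [f"bar{i}" for i in range(len(labels))]  # hypothetical placeholders for a.patches

set_x = sorted(set(labels))                 # unique labels, alphabetical
xind = [labels.index(i) for i in set_x]     # list.index returns the FIRST match
set_patches = [patches[i] for i in xind]

print(set_x)        # ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
print(xind)         # [0, 1, 3, 4, 5, 7, 8, 6]
print(set_patches)  # ['bar0', 'bar1', 'bar3', 'bar4', 'bar5', 'bar7', 'bar8', 'bar6']

# Alternative: unique labels in first-seen (plotting) order rather than sorted
first_seen = list(dict.fromkeys(labels))
print(first_seen)   # ['A', 'B', 'C', 'D', 'E', 'H', 'F', 'G']
```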
Using a color map:
```python
import matplotlib.pyplot as plt

x = ["A","B","B","C","D","E","H","F","G","H"]
y = [-25, -10, -5, 5, 10, 30, 35, 40, 50, 60]
w = [30, 20, 30, 25, 40, 20, 40, 40, 40, 30]

col_map = plt.get_cmap('tab20')  # qualitative colormap with 20 distinct colours

plt.figure(figsize=(20,10))

# centre of each bar: sum of the preceding widths plus half of its own width
xticks = []
for n in range(len(w)):
    xticks.append(sum(w[:n]) + w[n]/2)

# give each unique label its own colour from the colormap
set_x = sorted(set(x))
color_idx = {label: i for i, label in enumerate(set_x)}
colors = [col_map(color_idx[label]) for label in x]

a = plt.bar(xticks, height = y, width = w, color = colors, alpha = 0.8)
_ = plt.xticks(xticks, w)

# one legend handle per unique label: the first bar drawn with that label
xind = [x.index(label) for label in set_x]
set_patches = [a.patches[i] for i in xind]
plt.legend(set_patches, set_x)
```
Output: (plot image not preserved)
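Another option is to skip picking bars out of a.patches and build the legend from proxy artists (matplotlib.patches.Patch), one per unique label. The sketch below is an assumption-laden variant, not the approach above: label_colors is a hypothetical helper, labels keep their first-seen order, and it assumes a qualitative colormap such as tab20 (integer indices into a continuous colormap like viridis would give nearly identical neighbouring colours).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Patch


def label_colors(labels, cmap_name="tab20"):
    """Hypothetical helper: map each unique label (first-seen order) to one colormap colour."""
    cmap = plt.get_cmap(cmap_name)
    unique = list(dict.fromkeys(labels))  # drop duplicates, keep order
    # wrap around if there are more labels than colours in the colormap
    return {label: cmap(i % cmap.N) for i, label in enumerate(unique)}


x = ["A", "B", "B", "C", "D", "E", "H", "F", "G", "H"]
y = [-25, -10, -5, 5, 10, 30, 35, 40, 50, 60]
w = [30, 20, 30, 25, 40, 20, 40, 40, 40, 30]

# centre of each bar: sum of the preceding widths plus half of its own width
xticks = [sum(w[:n]) + w[n] / 2 for n in range(len(w))]
color_of = label_colors(x)

fig, ax = plt.subplots(figsize=(20, 10))
ax.bar(xticks, height=y, width=w, color=[color_of[label] for label in x], alpha=0.8)
ax.set_xticks(xticks)
ax.set_xticklabels([str(v) for v in w])

# proxy artists: one legend entry per unique label, independent of the bars
handles = [Patch(color=c, alpha=0.8, label=label) for label, c in color_of.items()]
ax.legend(handles=handles)
```

Because the legend handles are built from the same label-to-colour mapping as the bars, the two cannot drift out of sync even if the bars are reordered.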