I would like to dynamically update a heatmap made with Seaborn, adding data one row at a time. The sample code below gets the basic work done, but instead of updating the heatmap, it seems to nest a new one inside the previous one.
Thank you in advance for any help or solution you may provide.
```python
import numpy as np
np.random.seed(0)
import seaborn as sns
import matplotlib.pyplot as plt

data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape=(1, 5900))
for i, d in enumerate(data):
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    ax = sns.heatmap(data_to_draw, cmap="coolwarm")

    plt.draw()
    plt.pause(0.1)
```
Answer
I restructured your code to use `matplotlib.animation.FuncAnimation`.
To avoid drawing a new heatmap and a new colorbar in each iteration, you need to tell `seaborn.heatmap` which axes to draw each of them in, through its `ax` and `cbar_ax` parameters.
Moreover, before drawing each new heatmap, it is convenient to erase the previous one with `ax.cla()`.
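A minimal sketch of why `ax` and `cbar_ax` matter, assuming plain matplotlib: `ax.pcolormesh` stands in for `sns.heatmap` (which draws its heatmap as a pcolormesh), and the helper `count_axes_after_redraws` is hypothetical. When `Figure.colorbar` receives no `cax`, it shrinks the parent axes and adds a fresh colorbar axes on every call, which produces the "nested" look from the question; when a fixed `cax` is passed, the figure keeps exactly two axes.

```python
import matplotlib.pyplot as plt
import numpy as np


def count_axes_after_redraws(n_redraws, reuse_cbar_ax):
    """Hypothetical helper: redraw a heatmap-like mesh n times and count the figure's axes."""
    if reuse_cbar_ax:
        # Dedicated axes for the mesh and for the colorbar, as in the answer.
        fig, (ax, cbar_ax) = plt.subplots(
            1, 2, gridspec_kw={"width_ratios": (0.9, 0.05), "wspace": 0.2}
        )
    else:
        fig, ax = plt.subplots()
        cbar_ax = None
    data = np.random.rand(5, 50)
    for _ in range(n_redraws):
        ax.cla()  # clear the previous mesh from the heatmap axes
        mesh = ax.pcolormesh(data, cmap="coolwarm")
        if cbar_ax is None:
            # No cax: matplotlib shrinks `ax` and creates a brand-new colorbar axes.
            fig.colorbar(mesh, ax=ax)
        else:
            # cax given: the colorbar is drawn into the existing axes.
            fig.colorbar(mesh, cax=cbar_ax)
    n_axes = len(fig.axes)
    plt.close(fig)
    return n_axes
```

With three redraws, the first variant should end up with four axes (the heatmap axes plus three colorbar axes) and the second with two.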
Complete Code
```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

np.random.seed(0)

data = np.random.rand(120, 50)
data_to_draw = np.zeros(shape=(1, 50))


def animate(i):
    global data_to_draw
    # append the next row and keep at most 5 rows visible
    data_to_draw = np.vstack((data_to_draw, data[i]))
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    # clear the previous heatmap, then draw the new one into the same axes;
    # the colorbar always goes into the dedicated cbar_ax
    ax.cla()
    sns.heatmap(ax=ax, data=data_to_draw, cmap="coolwarm", cbar_ax=cbar_ax)


# one wide axes for the heatmap, one narrow axes for the colorbar
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws, figsize=(10, 8))

# keep a reference to the animation, otherwise it is garbage-collected
ani = FuncAnimation(fig=fig, func=animate, frames=100, interval=100)

plt.show()
```
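The `np.vstack`-then-slice step rebuilds the array on every frame and starts from a row of zeros. A sketch of an alternative (an assumption, not part of the code above): a `collections.deque` with `maxlen` discards the oldest row automatically, and `np.array(list(window))` turns it back into the 2-D array that `sns.heatmap` expects. The generator name `rolling_rows` is hypothetical.

```python
from collections import deque

import numpy as np


def rolling_rows(rows, max_rows=5):
    """Hypothetical generator: yield a 2-D array of the max_rows most recent rows."""
    window = deque(maxlen=max_rows)  # oldest row is dropped automatically when full
    for row in rows:
        window.append(row)
        # a fresh copy each time, shape (min(k, max_rows), n_cols)
        yield np.array(list(window))
```

Because `FuncAnimation` accepts any iterable as `frames` and passes each item to `func`, one could write `frames=rolling_rows(data)` and have `animate` receive the 2-D window directly. Unlike the original, the first frames then show fewer than five rows instead of a leading row of zeros.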
If you want to keep your original code structure, you can apply the same principles:
```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(0)

data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape=(1, 5900))

# one wide axes for the heatmap, one narrow axes for the colorbar, created once
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws, figsize=(10, 8))

for row in data:
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, row))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    # erase the previous heatmap, then redraw into the same axes
    ax.cla()
    sns.heatmap(ax=ax, data=data_to_draw, cmap="coolwarm", cbar_ax=cbar_ax)

    plt.draw()
    plt.pause(0.1)
```
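If redrawing with `sns.heatmap` on every step turns out to be slow with 5900 columns, a swapped-in technique (plain matplotlib, not seaborn) is to create the image once with `ax.imshow` and update it in place with `set_data`, so there is nothing to clear. A minimal sketch, assuming the color limits are known in advance (`vmin=0`, `vmax=1` fits `np.random.rand` data); `make_live_heatmap` and `push_row` are hypothetical names, and seaborn's tick labels are not reproduced.

```python
import matplotlib.pyplot as plt
import numpy as np


def make_live_heatmap(n_rows, n_cols, vmin=0.0, vmax=1.0, cmap="coolwarm"):
    """Hypothetical helper: create one image and one colorbar, return an update function."""
    fig, (ax, cbar_ax) = plt.subplots(
        1, 2, gridspec_kw={"width_ratios": (0.9, 0.05), "wspace": 0.2}
    )
    window = np.zeros((n_rows, n_cols))
    # Draw the image once; fixed vmin/vmax keep colors comparable between frames.
    image = ax.imshow(window, cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax)
    fig.colorbar(image, cax=cbar_ax)  # the colorbar is created once, in its own axes

    def push_row(row):
        nonlocal window
        # Drop the oldest row and append the new one (rolling window of n_rows rows).
        window = np.vstack((window[1:], row))
        image.set_data(window)  # update the existing artist instead of creating a new one
        return image

    return fig, image, push_row
```

In the question's loop, calling `fig, image, push_row = make_live_heatmap(5, 5900)` before the loop, then `push_row(row)` followed by `plt.pause(0.1)` inside it, would replace the `sns.heatmap` / `plt.draw()` calls. Here the window starts as five rows of zeros instead of growing to five.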