I would like to dynamically update a heatmap made with Seaborn, adding rows of data one by one. The sample code below gets the basic work done, but instead of updating the heatmap it seems to nest a new one inside the previous one.
Thank you in advance for any help/solution you may provide.
import numpy as np
np.random.seed(0)
import seaborn as sns
import matplotlib.pyplot as plt
data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape=(1, 5900))
for i, d in enumerate(data):
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]
    ax = sns.heatmap(data_to_draw, cmap="coolwarm")
    plt.draw()
    plt.pause(0.1)
Answer
I re-structured your code in order to exploit matplotlib.animation.FuncAnimation.
To avoid drawing a new heatmap and a new colorbar in each iteration, you need to specify the axes in which each of them is drawn, through the ax and cbar_ax parameters of seaborn.heatmap.
Moreover, before drawing each new heatmap, it is convenient to erase the previous one with ax.cla().
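To see why this matters, here is a minimal sketch that counts the axes in a figure after a few redraws. The variable names (fig_nested, nested_axes, reused_axes) and the Agg backend are only assumptions for illustration, so the snippet can run without a display. When cbar_ax is not given, each call to seaborn.heatmap adds a new colorbar axes and shrinks the main one, which is the "nesting" seen in the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)

# Case 1: no cbar_ax, so every call creates a fresh colorbar axes.
fig_nested, ax_nested = plt.subplots()
for _ in range(3):
    sns.heatmap(rng.random((5, 50)), ax=ax_nested, cmap="coolwarm")
nested_axes = len(fig_nested.axes)

# Case 2: fixed ax and cbar_ax, with the heatmap axes cleared before each redraw.
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw={"width_ratios": (0.9, 0.05)})
for _ in range(3):
    ax.cla()
    sns.heatmap(rng.random((5, 50)), ax=ax, cbar_ax=cbar_ax, cmap="coolwarm")
reused_axes = len(fig.axes)

print("axes without cbar_ax:", nested_axes)
print("axes with ax and cbar_ax:", reused_axes)
```

In the second case the figure keeps exactly the two axes created by plt.subplots, no matter how many times the heatmap is redrawn.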
Complete Code
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
np.random.seed(0)
data = np.random.rand(120, 50)
data_to_draw = np.zeros(shape = (1, 50))
def animate(i):
    global data_to_draw
    # append the i-th row to the data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]
    # erase the previous heatmap, then draw the new one on the same axes
    ax.cla()
    sns.heatmap(ax = ax, data = data_to_draw, cmap = "coolwarm", cbar_ax = cbar_ax)
# one wide axes for the heatmap and one narrow axes for the colorbar
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw = grid_kws, figsize = (10, 8))
ani = FuncAnimation(fig = fig, func = animate, frames = 100, interval = 100)
plt.show()
If you want to keep your original code structure, you can apply the same principles:
import numpy as np
np.random.seed(0)
import seaborn as sns
import matplotlib.pyplot as plt
data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape = (1, 5900))
# one wide axes for the heatmap and one narrow axes for the colorbar
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw = grid_kws, figsize = (10, 8))
for i, d in enumerate(data):
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]
    # erase the previous heatmap, then draw the new one on the same axes
    ax.cla()
    sns.heatmap(ax = ax, data = data_to_draw, cmap = "coolwarm", cbar_ax = cbar_ax)
    plt.draw()
    plt.pause(0.1)
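As a side note, the "keep max 5 rows" bookkeeping can also be sketched with collections.deque instead of np.vstack plus slicing. This is not part of the code above: rolling_window is a hypothetical helper name, and unlike the original it does not start from an initial row of zeros.

```python
from collections import deque

import numpy as np


def rolling_window(rows, max_rows=5):
    """Yield a 2-D array holding the most recent `max_rows` rows after each new row."""
    window = deque(maxlen=max_rows)  # the oldest row is dropped automatically
    for row in rows:
        window.append(row)
        yield np.vstack(list(window))


# Small demonstration with 6 rows of 2 values each.
frames = list(rolling_window(np.arange(12).reshape(6, 2)))
print([frame.shape for frame in frames])
```

Each yielded array could be passed as the data argument of sns.heatmap inside the loop or inside animate.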

