I would like to dynamically update a heatmap made with Seaborn, adding data one row at a time. The sample code below gets the basic work done, but instead of updating the heatmap, it seems to nest a new heatmap inside the previous one on every iteration.
Thank you in advance for any help/solution you may provide.
```python
import numpy as np
np.random.seed(0)
import seaborn as sns
import matplotlib.pyplot as plt

data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape=(1, 5900))
for i, d in enumerate(data):
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    ax = sns.heatmap(data_to_draw, cmap="coolwarm")
    plt.draw()
    plt.pause(0.1)
```
Answer
I restructured your code to use `matplotlib.animation.FuncAnimation`.
To avoid drawing a new heatmap and a new colorbar on each iteration, you need to tell `seaborn.heatmap` which axes to draw each of them on, through its `ax` and `cbar_ax` parameters.
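One way to see the difference is to count the axes in the figure. The following is a minimal headless sketch, assuming the non-interactive Agg backend and a small random array instead of the 5900-column one; it redraws five times, first the way the question does (no `ax`, no `cbar_ax`) and then with fixed `ax` and `cbar_ax`.

```python
# Headless check: count figure axes after five redraws.
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

np.random.seed(0)
data = np.random.rand(5, 50)  # small stand-in for the question's data

# 1) As in the question: seaborn uses plt.gca() and creates a colorbar each call
fig_q = plt.figure()
for row in range(1, 6):
    sns.heatmap(data[:row], cmap="coolwarm")
n_axes_question = len(fig_q.axes)  # heatmap axes + one colorbar axes per call

# 2) As in the answer: fixed heatmap axes and fixed colorbar axes
grid_kws = {"width_ratios": (0.9, 0.05), "wspace": 0.2}
fig_a, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws)
for row in range(1, 6):
    sns.heatmap(data[:row], ax=ax, cbar_ax=cbar_ax, cmap="coolwarm")
n_axes_answer = len(fig_a.axes)  # still only ax and cbar_ax

print(n_axes_question, n_axes_answer)
plt.close("all")
```

Without `cbar_ax`, every call makes matplotlib create a new colorbar axes by taking space from the heatmap axes, so after five calls the figure holds six axes, which is the nested look described in the question. With fixed axes the count stays at two.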
Moreover, before drawing each new heatmap, it is convenient to erase the previous one with `ax.cla()`.
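To illustrate why `ax.cla()` helps, this sketch (again assuming the Agg backend; `count_meshes` is a hypothetical helper name used only here) redraws into the same axes five times and counts the `QuadMesh` artists that `sns.heatmap` leaves in `ax.collections`.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this check
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

np.random.seed(0)
data = np.random.rand(5, 50)
grid_kws = {"width_ratios": (0.9, 0.05), "wspace": 0.2}


def count_meshes(clear_first):
    """Redraw five times and return how many heatmap meshes remain on ax."""
    fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws)
    for row in range(1, 6):
        if clear_first:
            ax.cla()  # remove the previous heatmap before drawing the new one
        sns.heatmap(data[:row], ax=ax, cbar_ax=cbar_ax, cmap="coolwarm")
    n = len(ax.collections)  # each sns.heatmap call adds one QuadMesh
    plt.close(fig)
    return n


print(count_meshes(clear_first=False), count_meshes(clear_first=True))
```

Without clearing, all five meshes stay on the axes, piled up under the newest one and redrawn every frame; with `ax.cla()` only the current one remains.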
Complete Code
```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

np.random.seed(0)

data = np.random.rand(120, 50)
data_to_draw = np.zeros(shape = (1, 50))

def animate(i):
    global data_to_draw
    # append the newest row
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible by dropping the oldest one
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    ax.cla()  # erase the previous heatmap
    # redraw into the same heatmap axes and the same colorbar axes
    sns.heatmap(ax = ax, data = data_to_draw, cmap = "coolwarm", cbar_ax = cbar_ax)

# one wide axes for the heatmap, one narrow axes for the colorbar
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw = grid_kws, figsize = (10, 8))
ani = FuncAnimation(fig = fig, func = animate, frames = 100, interval = 100)

plt.show()
```
If you want to keep your original code structure, you can apply the same principles:
```python
import numpy as np
np.random.seed(0)
import seaborn as sns
import matplotlib.pyplot as plt

data = np.random.rand(120, 5900)
data_to_draw = np.zeros(shape = (1, 5900))

# one wide axes for the heatmap, one narrow axes for the colorbar
grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw = grid_kws, figsize = (10, 8))

for i, d in enumerate(data):
    # update data to be drawn
    data_to_draw = np.vstack((data_to_draw, data[i]))
    # keep max 5 rows visible
    if data_to_draw.shape[0] > 5:
        data_to_draw = data_to_draw[1:]

    ax.cla()  # erase the previous heatmap
    sns.heatmap(ax = ax, data = data_to_draw, cmap = "coolwarm", cbar_ax = cbar_ax)
    plt.draw()
    plt.pause(0.1)
```
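With 5900 columns, rebuilding the whole seaborn heatmap on every frame can be slow. A different technique, which drops seaborn and is not the approach used above, is to draw the image and its colorbar once with matplotlib's `imshow` and then only replace the pixel values with `set_data`. This sketch assumes a window preallocated with zeros (so it always has 5 rows) and a color range fixed to [0, 1] because the sample data comes from `np.random.rand`; with real data you would choose `vmin`/`vmax` yourself.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

np.random.seed(0)
data = np.random.rand(120, 5900)
n_visible = 5  # keep max 5 rows visible, as in the question

# Preallocated window: the image shape never changes between frames
window = np.zeros((n_visible, data.shape[1]))

grid_kws = {"width_ratios": (0.9, 0.05), "wspace": 0.2}
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws, figsize=(10, 8))

# Draw the image and its colorbar once; color limits fixed (assumption: data in [0, 1))
im = ax.imshow(window, aspect="auto", cmap="coolwarm", vmin=0, vmax=1,
               interpolation="nearest")
fig.colorbar(im, cax=cbar_ax)


def animate(i):
    window[:-1] = window[1:]  # shift rows up, dropping the oldest one
    window[-1] = data[i]      # newest row goes at the bottom
    im.set_data(window)       # replace pixel values, no new artists created
    return (im,)


ani = FuncAnimation(fig, animate, frames=len(data), interval=100)
plt.show()
```

Because the image shape and color limits never change, nothing is re-created per frame; the trade-off is losing seaborn conveniences such as its tick labeling and per-frame automatic color scaling.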