I am wondering how I can make a heatmap using matplotlib. I can't really explain it, so I'll just show you an example. Here is an example grid:
1 4
3 2
It is stored in a dictionary like this:

```python
grid = {
    '0,0': 1,  # top left
    '1,0': 4,  # top right
    '0,1': 3,  # bottom left
    '1,1': 2,  # bottom right
}
```
How can I convert this dictionary into a heatmap using matplotlib?
Answer
Here’s a very barebones example:
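```python
import matplotlib.pyplot as plt

grid = {'0,0': 1, '1,0': 4, '0,1': 3, '1,1': 2}

# Side length of the (assumed square) grid: 4 cells -> 2x2.
n = int(len(grid) ** 0.5)

# Build a 2D list z[row][col]; each key is "x,y", i.e. "col,row".
z = [[0 for _ in range(n)] for _ in range(n)]
for coord, heat in grid.items():
    i, j = map(int, coord.split(","))
    z[j][i] = heat

fig, ax = plt.subplots()
im = ax.imshow(z)  # row 0 is drawn at the top by default

# Write each value in the centre of its cell (x = column, y = row).
for i in range(n):
    for j in range(n):
        ax.text(j, i, z[i][j], ha="center", va="center", color="w")

fig.tight_layout()
plt.show()
```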
Note that I assume you're plotting a square grid (hence how `n` is calculated). If it isn't square, you'd need to wrangle your data a bit more, for example by working out the number of rows and columns separately.
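One way to handle a non-square grid (a sketch, not the only option) is to take the width and height from the largest x and y found in the keys, rather than from the number of entries. The helper name `grid_to_rows` and the choice to fill missing cells with NaN are my own assumptions, not anything matplotlib requires.

```python
import math


def grid_to_rows(grid):
    """Convert {'x,y': value} into a list of rows (row index = y, column index = x).

    Hypothetical helper: works for rectangular grids, and any cell missing
    from the dictionary is filled with NaN.
    """
    # Parse every "x,y" key once into an (x, y) integer pair.
    cells = {tuple(map(int, key.split(","))): value for key, value in grid.items()}
    width = max(x for x, _ in cells) + 1   # number of columns
    height = max(y for _, y in cells) + 1  # number of rows
    rows = [[math.nan] * width for _ in range(height)]
    for (x, y), value in cells.items():
        rows[y][x] = value
    return rows


# Made-up 3-wide, 2-tall grid with one missing cell ('2,1').
example = {'0,0': 1, '1,0': 4, '2,0': 5, '0,1': 3, '1,1': 2}
z = grid_to_rows(example)
print(z)  # [[1, 4, 5], [3, 2, nan]]
```

You can then pass `z` to `ax.imshow` as before; when adding the text labels, loop over `range(len(z))` for rows and `range(len(z[0]))` for columns. `imshow` masks NaN values, so the missing cell should show up blank (though the text loop will label it "nan" unless you skip it).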
The main takeaway, however, is that you need to store your `z` values as a 2D array (a list of rows). In this case, we used each dictionary key to get the column and row where its value should go, and ended up with the array:
[[1, 4], [3, 2]]
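Of course, you'll still need to style the plot. I suggest taking a look at the matplotlib docs, in particular the `imshow` reference and the annotated heatmap examples in the gallery.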
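As a starting point for styling, here is a sketch that adds a colormap, a colorbar, per-cell ticks, axis labels and a title to the same 2x2 array. The colormap (`"hot"`), the label text, and the rule for picking black or white text are arbitrary choices of mine; swap in whatever suits your data.

```python
import matplotlib.pyplot as plt

# The 2D array built from the dictionary: z[row][col].
z = [[1, 4], [3, 2]]
n_rows, n_cols = len(z), len(z[0])

fig, ax = plt.subplots()
# "hot" is one of matplotlib's built-in colormaps; any other name works too.
im = ax.imshow(z, cmap="hot")

# A colorbar shows which colour corresponds to which value.
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("heat")

# One tick per cell, labelled with its x (column) / y (row) index.
ax.set_xticks(range(n_cols))
ax.set_yticks(range(n_rows))
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Heatmap of grid")

for row in range(n_rows):
    for col in range(n_cols):
        value = z[row][col]
        # im.norm maps a value to 0..1 on the colormap; use dark text on
        # bright cells and light text on dark cells so the number stays readable.
        color = "black" if im.norm(value) > 0.5 else "white"
        ax.text(col, row, value, ha="center", va="center", color=color)

fig.tight_layout()
plt.show()
```

Choosing the text colour from `im.norm` keeps the labels readable whichever colormap you end up using, instead of hard-coding white text.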