Imagine you have a segmentation map where each object is identified by a unique index, for example the one generated and plotted by the code below.
For each object, I would like to save which pixels it covers, but so far I could only come up with the standard `for` loop. Unfortunately, for larger images with thousands of individual objects, this turns out to be very slow, at least for my real data. Can I somehow speed things up?
```python
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.draw import random_shapes

# please ignore that this does not always produce 20 objects each with a
# unique color. it is simply a quick way to produce data that is similar to
# my problem that can also be visualized.
segmap, labels = random_shapes(
    (100, 100),
    20,
    min_size=6,
    max_size=20,
    channel_axis=None,  # older scikit-image versions: multichannel=False
    intensity_range=(0, 20),
    num_trials=100,
)
segmap = np.ma.masked_where(segmap == 255, segmap)

object_idxs = np.unique(segmap)[:-1]
objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])

# important bit here:
# this I can vectorize
objects['idx'] = object_idxs
# but this I cannot. and it takes forever.
for i in range(object_idxs.size):
    # compare against the object's label, not the loop counter
    objects[i]['pixels'] = np.where(segmap == object_idxs[i])

# just plotting here
fig, ax = plt.subplots(constrained_layout=True)
image = ax.imshow(
    segmap, cmap='tab20', norm=mpl.colors.Normalize(vmin=0, vmax=20)
)
fig.colorbar(image)
fig.show()
```
Answer
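Using `np.where` in a loop is not efficient algorithmically, since the time complexity is O(s n m), where `s = object_idxs.size` and `n, m = segmap.shape`: every iteration scans the whole image again. This operation can be done in O(n m). The sort-based solution below runs in O(n m log(n m)), which is still far better than O(s n m) when there are thousands of objects.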
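If a strictly linear-time variant is wanted, one option is sketched below under stated assumptions. According to the `numpy.sort` documentation, `kind="stable"` uses radix sort (O(n)) for integer types of 16 bits or less, so if the labels are non-negative and fit in `uint16`, a single stable `argsort` groups all pixels in O(n m). The helper name `pixels_per_label` and its dict return type are hypothetical choices for this sketch, not part of the original answer.

```python
import numpy as np


def pixels_per_label(segmap, background):
    """Map each label to its pixel coordinates, like np.where(segmap == label).

    Hypothetical helper (sketch). Assumes non-negative integer labels below
    65536 so they can be cast to uint16, for which NumPy's stable argsort
    uses a linear-time radix sort.
    """
    segmap = np.asarray(segmap)
    flat = segmap.ravel()
    keep = np.flatnonzero(flat != background)  # flat indices of object pixels
    if keep.size == 0:
        return {}
    labels = flat[keep]
    if labels.min() < 0 or labels.max() > np.iinfo(np.uint16).max:
        raise ValueError("labels must fit in uint16 for this sketch")
    labels = labels.astype(np.uint16)

    # Stable sort: pixels of one label keep their row-major (np.where) order.
    order = np.argsort(labels, kind="stable")
    keep, labels = keep[order], labels[order]

    # Start of each run of equal labels, plus the end of each run.
    starts = np.flatnonzero(np.r_[True, labels[1:] != labels[:-1]])
    ends = np.r_[starts[1:], labels.size]
    rows, cols = np.unravel_index(keep, segmap.shape)
    return {int(labels[s]): (rows[s:e], cols[s:e]) for s, e in zip(starts, ends)}
```

For the question's masked data, this would be called as `pixels_per_label(np.ma.getdata(segmap), 255)`, since `np.ma.getdata` returns the underlying array in which the background is still 255.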
One solution using NumPy is to first select the locations of all object pixels, then sort them by their associated object label in `segmap`, and finally split the sorted coordinates into one group per object, using the number of pixels each object covers. Here is the code:
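To make the three steps (select, sort, split) concrete, here is a tiny hand-checkable run. The 3x4 map and the background value 9 are made up for illustration; each resulting block holds the rows and columns that `np.where(segmap == label)` would return.

```python
import numpy as np

# Toy segmentation map with objects 0, 1, 2; the value 9 plays the role of
# the background (255 in the question).
segmap = np.array([
    [9, 1, 1, 9],
    [0, 9, 1, 2],
    [0, 0, 9, 2],
])
background = 9

mask = segmap != background      # step 1: keep object pixels only
labels = segmap[mask]            # their labels, in row-major order
i, j = np.where(mask)            # their coordinates, in the same order

# step 2: group by label (stable, so each group stays in row-major order)
ordering = np.argsort(labels, kind="stable")
unique_labels, counts = np.unique(labels, return_counts=True)
indices = np.vstack([i[ordering], j[ordering]])

# step 3: split at the cumulative pixel counts -> one (2, k) block per object
per_object = np.split(indices, counts.cumsum()[:-1], axis=1)

for label, block in zip(unique_labels, per_object):
    print(label, block.tolist())
# 0 [[1, 2, 2], [0, 0, 1]]
# 1 [[0, 0, 1], [1, 2, 2]]
# 2 [[1, 2], [3, 3]]
```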
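```python
import numpy as np

# np.ma.getdata returns the underlying data of the question's masked array
# (so the 255 background is still present) and leaves a plain ndarray as-is.
# On the masked array itself, np.max would skip the background and return
# the largest object label instead.
data = np.ma.getdata(segmap)
background = np.max(data)
mask = data != background
objects = data[mask]
uniqueObjects, counts = np.unique(objects, return_counts=True)
# A stable sort keeps each object's pixels in the same row-major order as np.where
ordering = np.argsort(objects, kind='stable')
i, j = np.where(mask)
indices = np.vstack([i[ordering], j[ordering]])
indicesPerObject = np.split(indices, counts.cumsum()[:-1], axis=1)

objects = np.empty(uniqueObjects.size, dtype=[('idx', 'i4'), ('pixels', 'O')])
objects['idx'] = uniqueObjects
for i in range(uniqueObjects.size):
    # Use `tuple(...)` to get the exact same type as the initial code here
    objects[i]['pixels'] = tuple(indicesPerObject[i])

# In case the conversion to tuple is not required, the loop can also be accelerated:
# objects['pixels'] = indicesPerObject
```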
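To check that the sorted-and-split result matches the original loop exactly, and to compare run times on your own machine, a harness like the following could be used. The random 200x200 test image with 1000 labels and the function names `pixels_naive` and `pixels_sorted` are assumptions for this sketch; no timings are claimed here.

```python
import time

import numpy as np


def pixels_naive(segmap, object_idxs):
    # The question's approach: one full-image comparison per object, O(s n m).
    return [np.where(segmap == idx) for idx in object_idxs]


def pixels_sorted(segmap, background):
    # The answer's approach, with kind="stable" so each object's pixels come
    # out in the same row-major order as np.where.
    mask = segmap != background
    labels = segmap[mask]
    unique_labels, counts = np.unique(labels, return_counts=True)
    ordering = np.argsort(labels, kind="stable")
    i, j = np.where(mask)
    indices = np.vstack([i[ordering], j[ordering]])
    per_object = np.split(indices, counts.cumsum()[:-1], axis=1)
    return unique_labels, [tuple(block) for block in per_object]


# Synthetic data: labels 0..999 scattered over the image, 1000 is background.
rng = np.random.default_rng(42)
segmap = rng.integers(0, 1001, size=(200, 200))
background = 1000

t0 = time.perf_counter()
unique_labels, fast = pixels_sorted(segmap, background)
t1 = time.perf_counter()
slow = pixels_naive(segmap, unique_labels)
t2 = time.perf_counter()

same = all(
    np.array_equal(a[0], b[0]) and np.array_equal(a[1], b[1])
    for a, b in zip(fast, slow)
)
print(f"identical: {same}, sorted: {t1 - t0:.4f}s, loop: {t2 - t1:.4f}s")
```

Without `kind='stable'`, each object would still get the same set of pixels, but possibly in a different order than `np.where` produces, so an exact `array_equal` comparison could fail.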