
Is there a way to speed up looping over numpy.where?

Imagine you have a segmentation map in which each object is identified by a unique index, looking similar to this:

[Image: example segmentation map]

For each object, I would like to save which pixels it covers, but so far I have only come up with a standard for loop. Unfortunately, for larger images with thousands of individual objects this turns out to be very slow, at least for my real data. Can I somehow speed things up?

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.draw import random_shapes


# please ignore that this does not always produce 20 objects each with a
# unique color. it is simply a quick way to produce data that is similar to
# my problem that can also be visualized.
segmap, labels = random_shapes(
    (100, 100), 20, min_size=6, max_size=20,
    channel_axis=None,  # scikit-image < 0.19: use multichannel=False instead
    intensity_range=(0, 20), num_trials=100,
)
segmap = np.ma.masked_where(segmap == 255, segmap)

object_idxs = np.unique(segmap)[:-1]
objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])

# important bit here:

# this I can vectorize
objects['idx'] = object_idxs
# but this I cannot, and it takes forever.
for i in range(object_idxs.size):
    # compare against the actual object label, not the loop counter
    objects[i]['pixels'] = np.where(segmap == object_idxs[i])

# just plotting here
fig, ax = plt.subplots(constrained_layout=True)
image = ax.imshow(
    segmap, cmap='tab20', norm=mpl.colors.Normalize(vmin=0, vmax=20)
)
fig.colorbar(image)
plt.show()


Answer

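Using np.where in a loop is not efficient algorithmically: each call scans the whole image, so the total time complexity is O(s n m), where s = object_idxs.size and n, m = segmap.shape. Grouping the pixels in a single pass instead costs O(n m log(n m)) with a comparison sort, or O(n m) with a counting or radix sort.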
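As a rough sketch of the O(n m) variant, assuming all labels are non-negative integers below 65536: `np.bincount` replaces `np.unique` for counting pixels per label, and the stable argsort on a `uint16` array relies on NumPy's documented use of radix sort (linear time) for integer types of 16 bits or less. The function name `pixels_linear` and the dict return type are illustrative choices, not part of the answer.

```python
import numpy as np


def pixels_linear(segmap, background):
    """Group pixel coordinates by label, returning {label: (rows, cols)}.

    Sketch assuming integer labels in [0, 65535]; for such uint16 data,
    NumPy's stable argsort uses radix sort, which is linear in the pixel count.
    """
    data = np.asarray(segmap)           # drops the mask of a masked array
    mask = data != background
    rows, cols = np.nonzero(mask)       # row-major order, same as np.where
    labels = data[mask]
    if labels.size and (labels.min() < 0 or labels.max() > 65535):
        raise ValueError("this sketch needs labels in [0, 65535]")
    labels = labels.astype(np.uint16)
    order = np.argsort(labels, kind="stable")  # radix sort for <= 16-bit ints
    counts = np.bincount(labels)               # pixels per label, no sorting
    present = np.flatnonzero(counts)           # labels that actually occur
    splits = np.cumsum(counts[present])[:-1]
    row_groups = np.split(rows[order], splits)
    col_groups = np.split(cols[order], splits)
    return {int(lab): (r, c) for lab, r, c in zip(present, row_groups, col_groups)}


seg = np.array([[255, 3, 3],
                [7, 255, 3],
                [7, 7, 255]])
groups = pixels_linear(seg, background=255)
```

For the masked `segmap` from the question, `np.asarray` discards the mask, so `background=255` (the value `random_shapes` uses for empty pixels) is the one to pass.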

One NumPy-only solution is to first select all object pixel locations, then sort them by the object label stored in segmap, and finally split them according to the number of pixels per object. Here is the code:

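import numpy as np

# np.ma.getdata returns the plain array whether or not segmap is masked;
# random_shapes fills the background with 255, the largest value
data = np.ma.getdata(segmap)
background = np.max(data)
mask = data != background
objects = data[mask]
uniqueObjects, counts = np.unique(objects, return_counts=True)
# a stable sort keeps each object's pixels in row-major order, like np.where
ordering = np.argsort(objects, kind='stable')
rows, cols = np.where(mask)
indices = np.vstack([rows[ordering], cols[ordering]])
indicesPerObject = np.split(indices, counts.cumsum()[:-1], axis=1)

objects = np.empty(uniqueObjects.size, dtype=[('idx', 'i4'), ('pixels', 'O')])
objects['idx'] = uniqueObjects
for k in range(uniqueObjects.size):
    # Use `tuple(...)` to get the exact same type as the initial code here
    objects[k]['pixels'] = tuple(indicesPerObject[k])
# This loop only performs s cheap assignments, so it is not the bottleneck.
# Assigning `objects['pixels'] = indicesPerObject` directly may fail, since NumPy
# can try to broadcast the (2, count) arrays instead of storing them as objects.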
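The same select/sort/split idea can be packaged as a reusable function and checked against the slow per-label np.where loop. This is a minimal sketch: the names `pixels_per_object` and `pixels_per_object_loop`, the dict return type, and the random test image with 255 as background are hypothetical choices for illustration.

```python
import numpy as np


def pixels_per_object(segmap, background):
    """Return {label: (rows, cols)} using one mask, one sort and one split."""
    data = np.asarray(segmap)
    mask = data != background
    labels = data[mask]
    unique_labels, counts = np.unique(labels, return_counts=True)
    order = np.argsort(labels, kind="stable")  # keeps row-major order per label
    rows, cols = np.nonzero(mask)
    splits = np.cumsum(counts)[:-1]
    row_groups = np.split(rows[order], splits)
    col_groups = np.split(cols[order], splits)
    return {int(lab): (r, c) for lab, r, c in zip(unique_labels, row_groups, col_groups)}


def pixels_per_object_loop(segmap, background):
    """Reference version: one full-image np.where per label (the slow loop)."""
    data = np.asarray(segmap)
    labels = np.unique(data[data != background])
    return {int(lab): np.where(data == lab) for lab in labels}


rng = np.random.default_rng(0)
seg = rng.integers(0, 30, size=(64, 64))
seg[rng.random(seg.shape) < 0.3] = 255  # hypothetical background value
fast = pixels_per_object(seg, 255)
slow = pixels_per_object_loop(seg, 255)
```

Because the sort is stable, each object's coordinates come out in the same row-major order that np.where produces, so the two dictionaries should match element for element.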