I am playing with numpy and digging through the documentation, and I have come across some magic. Namely, I am talking about `numpy.where()`:
```python
>>> x = np.arange(9.).reshape(3, 3)
>>> np.where(x > 5)
(array([2, 2, 2]), array([0, 1, 2]))
```
How do they achieve internally that you are able to pass something like `x > 5` into a method? I guess it has something to do with `__gt__`, but I am looking for a detailed explanation.
Answer
> How do they achieve internally that you are able to pass something like `x > 5` into a method?
The short answer is that they don’t.
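Any comparison on a numpy array returns a boolean array: `__gt__`, `__lt__`, etc. all return boolean arrays that are `True` wherever the given condition holds. Python evaluates `x > 5` (i.e. calls `x.__gt__(5)`) before the function call happens, so `np.where` never sees the expression itself, only the boolean array it produced.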
E.g.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
print(x > 5)
```

yields:

```
[[False False False]
 [False False False]
 [ True  True  True]]
```
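A quick way to convince yourself of this is to build the mask in a few equivalent ways and hand the finished array to `np.where`. This sketch only uses standard numpy calls (`ndarray.__gt__`, `np.greater`, `np.array_equal`); the variable names are illustrative.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# Three spellings that build the same boolean array.
mask_op = x > 5               # operator syntax
mask_dunder = x.__gt__(5)     # what the operator calls under the hood
mask_ufunc = np.greater(x, 5) # the ufunc that implements the comparison

print(mask_op.dtype)                          # bool
print(np.array_equal(mask_op, mask_dunder))   # True
print(np.array_equal(mask_op, mask_ufunc))    # True

# np.where only ever receives the finished array, so passing a
# precomputed mask gives exactly the same result as np.where(x > 5).
rows, cols = np.where(mask_op)
print(rows, cols)                             # [2 2 2] [0 1 2]
```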
This is also why something like `if x > 5:` raises a `ValueError` when `x` is a numpy array with more than one element: the comparison produces an array of `True`/`False` values, not a single value, so its truth value is ambiguous.
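Here is a small sketch of that failure and the usual fix, which is to reduce the boolean array to one value explicitly with `.any()` or `.all()`. The `raised` flag is just an illustrative variable.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

raised = False
try:
    if x > 5:          # a 9-element boolean array has no single truth value
        pass
except ValueError as exc:
    raised = True
    print("ValueError:", exc)

# Collapse the boolean array to a single bool explicitly instead.
if (x > 5).any():
    print("at least one element is greater than 5")
if not (x > 5).all():
    print("not every element is greater than 5")
```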
Furthermore, numpy arrays can be indexed by boolean arrays. E.g. `x[x > 5]` yields `[6 7 8]` in this case.
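A short sketch of boolean indexing, both for selecting and for in-place assignment, alongside the equivalent selection through the index tuple that `np.where` returns:

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5

# Boolean indexing returns the selected elements as a 1-D array (row-major order).
print(x[mask])             # [6 7 8]

# The same mask can drive assignment; work on a copy to keep x unchanged.
y = x.copy()
y[mask] = 0
print(y)
# [[0 1 2]
#  [3 4 5]
#  [0 0 0]]

# Indexing with the tuple of index arrays from np.where selects the same elements.
print(x[np.where(mask)])   # [6 7 8]
```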
Honestly, it's fairly rare that you actually need `numpy.where`, but when called with only a condition it just returns the indices where a boolean array is `True`. Usually you can do what you need with simple boolean indexing.
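For completeness, a sketch of the two forms of `np.where`: with only a condition it gives the same indices as `np.nonzero`, and with three arguments it picks elementwise from the second or third argument instead of returning indices. The cap value `5.0` is an arbitrary choice for illustration.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)
mask = x > 5

# One argument: same index arrays as np.nonzero on the condition.
idx_where = np.where(mask)
idx_nonzero = np.nonzero(mask)
print(all(np.array_equal(a, b) for a, b in zip(idx_where, idx_nonzero)))  # True

# Three arguments: take 5.0 where mask is True, otherwise keep x.
capped = np.where(mask, 5.0, x)
print(capped)
# [[0. 1. 2.]
#  [3. 4. 5.]
#  [5. 5. 5.]]
```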