Python NumPy For Your Grandma - 4.1 where()
In this section, we’ll see how you can use NumPy’s where() function as a vectorized approach to writing if-else statements.
Suppose you have two 1-D arrays, foo and bar, each with 10 million elements.
```python
import numpy as np

generator = np.random.default_rng()

foo = generator.integers(6, size=10**7)
print(foo)
## [1 5 0 ... 4 2 3]

bar = generator.integers(6, size=10**7)
print(bar)
## [5 1 0 ... 1 1 1]
```
You want to create a third array called baz such that, where bar is even, you double the corresponding value of foo; otherwise, you take half the corresponding value of foo. You might be inclined to do this with a for loop like this.
```python
baz = np.zeros(foo.size)
for i in range(foo.size):
    if bar[i] % 2 == 0:
        baz[i] = foo[i] * 2
    else:
        baz[i] = foo[i] / 2

print(baz)
## [0.5 2.5 0. ... 2. 1. 1.5]
```
You probably noticed that this loop took a while to execute. The slowdown happens because, on every iteration of the loop, Python has to call some lower-level function in C, and then C has to hand the result back to Python. This back-and-forth communication between Python and C is the bottleneck in our implementation.
NumPy is specifically designed to address this problem. Instead of going back and forth between C and Python, we want to say, “Hey C, here’s my array and here’s what I want to do with it. Now carry out the entire operation and come back to me with the result.” This is called vectorization: C operates on a batch, or vector, of data before handing the result back to Python. In general, if you’re writing a for loop to process a NumPy array, you’re doing it wrong.
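To make the idea concrete, here is a minimal sketch that computes the same thing two ways: an element-by-element Python loop, and a single vectorized NumPy expression. The array size, the seed, and the use of time.perf_counter() for rough timing are illustrative choices, not part of the original example. Exact timings will vary by machine.

```python
import time

import numpy as np

rng = np.random.default_rng(0)
values = rng.integers(6, size=10**6)  # 1M integers in [0, 6)

# Loop version: Python visits each element and does the arithmetic itself
start = time.perf_counter()
looped = np.zeros(values.size)
for i in range(values.size):
    looped[i] = values[i] * 2
loop_seconds = time.perf_counter() - start

# Vectorized version: one call; NumPy's compiled code loops over the whole array
start = time.perf_counter()
vectorized = values * 2
vec_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.3f}s, vectorized: {vec_seconds:.5f}s")
print(np.array_equal(looped, vectorized))  # True
```

The results match even though looped is a float array and vectorized is an integer array, because array_equal() compares values, not dtypes.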
Going back to our example, the vectorized way to solve this problem is to use NumPy’s where() function. where() takes three main parameters:
- a boolean array
- values to use when the boolean array is True
- and values to use when the boolean array is False
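Here is a tiny example, small enough to check by hand, showing how where() picks from its second and third arguments element by element. It also shows two behaviours worth knowing: scalars are broadcast, so you can pass a single value as either branch, and calling where() with only the condition returns the indices where it is True (the same as np.nonzero()). The arrays here are made up for illustration.

```python
import numpy as np

condition = np.array([True, False, True, False])
if_true = np.array([1, 2, 3, 4])
if_false = np.array([10, 20, 30, 40])

# Take from if_true where condition is True, else from if_false
picked = np.where(condition, if_true, if_false)
print(picked)  # [ 1 20  3 40]

# Scalars broadcast: keep values greater than 2, replace the rest with 0
nums = np.array([1, 5, 2, 7])
clipped = np.where(nums > 2, nums, 0)
print(clipped)  # [0 5 0 7]

# With only a condition, where() returns a tuple of index arrays (one per dimension)
idx = np.where(nums > 2)[0]
print(idx)  # [1 3]
```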
So, we can solve our original problem like this.
```python
boz = np.where(bar % 2 == 0, foo * 2, foo / 2)
print(boz)
## [0.5 2.5 0. ... 2. 1. 1.5]
```
Let’s use the array_equal() function to make sure baz and boz really are the same.
```python
np.array_equal(baz, boz)
## True
```
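One thing to keep in mind: where() is not lazy. Both foo * 2 and foo / 2 are computed for every element before where() picks between them, so you trade some extra arithmetic (and temporary arrays) for avoiding the Python loop. The same idea extends to if/elif/else logic by nesting calls. The sketch below uses a hypothetical three-way rule on small made-up arrays (not part of the original problem): double where the flag is 0, halve where it is even but nonzero, and leave the value alone otherwise.

```python
import numpy as np

small_foo = np.array([1, 5, 0, 4, 2, 3])
small_bar = np.array([0, 1, 2, 4, 1, 0])

# Equivalent to, per element:
#   if bar == 0:        foo * 2
#   elif bar % 2 == 0:  foo / 2
#   else:               foo
result = np.where(
    small_bar == 0,
    small_foo * 2,
    np.where(small_bar % 2 == 0, small_foo / 2, small_foo),
)
print(result)  # [2. 5. 0. 2. 2. 6.]
```

For many branches, np.select(), which takes a list of conditions and a matching list of choices, may read more clearly than deeply nested where() calls.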
If you’re in a Jupyter or IPython notebook like this one, you can use a magic command to measure the runtime of a block of code. If you do that here, you can see just how much faster the vectorized where() function is compared to the Python loop.
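Outside a notebook, a rough equivalent is Python’s built-in timeit module. The sketch below times the loop and the where() call on a smaller array than the original (10**6 elements, an arbitrary choice to keep it quick). Absolute numbers depend on your machine, so none are shown here.

```python
import timeit

import numpy as np

generator = np.random.default_rng()
foo = generator.integers(6, size=10**6)
bar = generator.integers(6, size=10**6)


def with_loop():
    # Element-by-element if/else in pure Python
    baz = np.zeros(foo.size)
    for i in range(foo.size):
        if bar[i] % 2 == 0:
            baz[i] = foo[i] * 2
        else:
            baz[i] = foo[i] / 2
    return baz


def with_where():
    # Same logic as a single vectorized call
    return np.where(bar % 2 == 0, foo * 2, foo / 2)


# Run each function 3 times; timeit returns the total elapsed seconds
loop_time = timeit.timeit(with_loop, number=3)
where_time = timeit.timeit(with_where, number=3)
print(f"loop:  {loop_time:.3f}s")
print(f"where: {where_time:.3f}s")
print(f"speedup: ~{loop_time / where_time:.0f}x")
```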