Python NumPy For Your Grandma - 6.2 einsum()

In this section, we’ll see how you can calculate Einstein sums with the einsum() function to quickly and efficiently multiply and sum arrays.

Let’s start by looking at a simple example. Here, we have 1d arrays A and B, each with four elements.

import numpy as np

A = np.arange(4)
print(A)
## [0 1 2 3]
B = A + 4
print(B)
## [4 5 6 7]

If we run np.einsum('i,j->i', A, B), we get back the array [ 0, 22, 44, 66].

np.einsum('i,j->i', A, B)
## array([ 0, 22, 44, 66])

So, what the hell just happened? In pseudocode, you could describe it as follows.

initialize output as an array of 0s the same size as A  
for each i  
  for each j  
    output_i += Ai * Bj 

The first parameter of einsum() is the subscript string. Assuming we’re operating on two arrays A and B, it always has the form “subscripts for A's axes comma subscripts for B's axes arrow subscripts for output axes”. In this case, A has one dimension so we give it the letter i, B has one dimension, so we give it the letter j, and by using the letter i for the output array, we’re saying “the output array has one dimension and it’s the same length as A”. In other words, element Ai will always feed into a corresponding output_i.
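
To make this concrete, here’s the pseudocode above written out as plain Python loops. This is just a slow, illustrative translation; einsum() does the same work in compiled code without any Python-level looping.

output = np.zeros_like(A)
for i in range(len(A)):
    for j in range(len(B)):
        output[i] += A[i] * B[j]
print(output)
## [ 0 22 44 66]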

Let’s see some more examples, and I think things’ll start to soak in. If we do np.einsum('i,j->', A, B), we get back 132.

np.einsum('i,j->', A, B)
## 132

We could rewrite this process in pseudocode as

initialize output = 0  
for each i  
  for each j  
    output += Ai * Bj

Since there are no letters to the right of the arrow in the subscript string, our output will have 0 dimensions. In other words, our output will be a scalar. And like before, we iterate over each i in A and each j in B, adding Ai times Bj to the sum.
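
Again, a rough loop translation (for illustration only) confirms the result.

total = 0
for i in range(len(A)):
    for j in range(len(B)):
        total += A[i] * B[j]
print(total)
## 132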

What if we do np.einsum('z,z->z', A, B)?

np.einsum('z,z->z', A, B)
## array([ 0,  5, 12, 21])

In this case, the pseudocode would be

initialize output as an array of 0s the same size as A (and B)  
for each z  
  output_z += Az * Bz

So this example is equivalent to doing the element-wise product, A times B.
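
You can confirm the equivalence directly:

print(A * B)
## [ 0  5 12 21]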

How about this one? np.einsum('s,t->st', A, B)

np.einsum('s,t->st', A, B)
## array([[ 0,  0,  0,  0],
##        [ 4,  5,  6,  7],
##        [ 8, 10, 12, 14],
##        [12, 15, 18, 21]])

In this case, we get back a 4x4 array and you could write the pseudocode for the process as

initialize output as a 4x4 array of 0s  
for each s  
  for each t  
    output_st += As * Bt
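
In other words, this builds the outer product of A and B, so you should get the same result from np.outer().

print(np.outer(A, B))
## [[ 0  0  0  0]
##  [ 4  5  6  7]
##  [ 8 10 12 14]
##  [12 15 18 21]]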

einsum() really starts to shine in two dimensions. Let’s make 2x2 arrays C and D.

C = np.arange(4).reshape(2,2)
print(C)
## [[0 1]
##  [2 3]]
D = C + 4
print(D)
## [[4 5]
##  [6 7]]

If we do an Einstein sum like np.einsum('ij,ji->', C, D), we get back 37.

np.einsum('ij,ji->', C, D)
## 37

Before we write out the pseudocode for this one, let’s dissect the subscript string and see what it tells us. The bit before the first comma is ij. This tells us the first array we’re operating on has exactly two dimensions. The bit after the comma is ji. This tells us the second array we’re operating on also has exactly two dimensions. Also, the length of its first dimension matches the length of the first array’s second dimension since they both use subscript j, and the length of its second dimension matches the length of the first array’s first dimension since they both use subscript i. The bit after the arrow is empty, so we know the output will have 0 dimensions, and thus will be a scalar.
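
As a quick aside, this dimension matching is easier to see with non-square arrays. Here’s a made-up example (the arrays E and F exist only for this illustration): since E is 2x3, F has to be 3x2 for 'ij,ji->' to make sense.

E = np.arange(6).reshape(2,3)
F = np.arange(6).reshape(3,2) + 6
print(np.einsum('ij,ji->', E, F))
## 140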

With that in mind, we can write the pseudocode as

initialize output = 0
for each i
  for each j
    output += Cij * Dji
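
And here’s a loop-based sanity check (again, just for illustration):

total = 0
for i in range(2):
    for j in range(2):
        total += C[i, j] * D[j, i]
print(total)
## 37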

Notice that, on the surface, this particular Einstein sum is equivalent to doing np.sum(C * D.T). The difference is that einsum() only visits each element once, whereas sum(C * D.T) visits each element twice: first when it does the multiplication and again when it does the sum. More importantly, sum(C * D.T) creates a temporary array that takes up memory before it gets summed into a scalar. einsum() avoids that extra memory consumption, which can make a significant difference if you’re dealing with big arrays.
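
You can verify the equivalence of the results (though not the memory savings) directly:

print(np.sum(C * D.T))
## 37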

