Intuitively Approaching Einstein Summation Notation

Warning

This post covers einstein summation notation syntax in terms of programming languages.

You can read a more neatly formatted version of this post at my blog.

Einstein summation notation (or einsum notation for short) is a handy way to write various matrix operations in a succinct, universal manner. With it, you can probably forget all the various symbols and operators there are and stick to one common syntax, that once understood, can be more intuitive.

For example, matrix multiplication can be written as ik, kj -> ij and a transpose of a matrix can be written as ij -> ji.

Let’s figure this out.

General Rules

The following are two general rules one can use to quickly write einsum notation.

Repeating letters between input arrays means that values along those axes will be multiplied together.

Omitting a letter from the output means that values along that axis will be summed.

However, I don’t find these rules intuitive, even a little confusing. Why?

Matrices have the order of row by column. A 2x3 matrix has 2 rows and 3 columns. When we perform matrix multiplication, we take the dot product of each row in the first matrix with each column in the second matrix.

However, when the einsum rules above — specifically the first rule — are used to denote matrix multiplication (ik, kj \rightarrow ij, as depicted below), the order of a matrix appears to change.

In order for the einsum rules and the definition of matrix multiplcation above to stay consistent, k now denotes the rows in the first matrix and columns in the second matrix, thereby changing the order of a matrix to column by row.

But even if we let k denote the columns in the first matrix, we end up doing dot products with each column in the first matrix and with each row in the second matrix.

Not intuitive.

A More Intuitive Way

The key to understanding einsum notation is to not think of axes, but of iterators. For example, i is an iterator that returns the rows of a matrix. j is an iterator that returns the columns of a matrix.

Let’s begin with a relatively more simple example: the hadamard product (also known as the elementwise product or elementwise multiplication.)

Hadamard Product

We have the following two matrices.

A = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}, B = \begin{bmatrix} 9 & 8 & 7 \\ 6 & 5 & 4 \\ 3 & 2 & 1 \end{bmatrix}

To access the element 8 in matrix A, we need to return the second row and first column[1]. This can be denoted as a_{21}. The first digit in the subscript refers to the row and the second refers to the column. We can refer to any entry generally as a_{ij}.

Taking the hadamard product looks like this.

\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} \odot \begin{bmatrix} 9 & 8 & 7 \\ 6 & 5 & 4 \\ 3 & 2 & 1 \end{bmatrix} = \begin{bmatrix} 1 \cdot 9 & 2 \cdot 8 & 3 \cdot 7 \\ 4 \cdot 6 & 5 \cdot 5 & 6 \cdot 4 \\ 7 \cdot 3 & 8 \cdot 2 & 9 \cdot 1 \end{bmatrix} = C

An alternative way to look at it…

Warning

Don’t dwell too much on this; it may help to refer back to this later to help understand the einsum notation below.

\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} \odot \begin{bmatrix} 9 & 8 & 7 \\ 6 & 5 & 4 \\ 3 & 2 & 1 \end{bmatrix} = \begin{bmatrix} a_{00}b_{00} & a_{01}b_{01} & a_{02}b_{02} \\ a_{10}b_{10} & a_{11}b_{11} & a_{12}b_{12} \\ a_{20}b_{20} & a_{21}b_{21} & a_{22}b_{22} \\ \end{bmatrix} = C

In words, what’s happening is that we’re looping through all the rows of A and B. For each row, we also loop through each column and multiply those columns together.

\text{for row in } A \text{ and } B

\text{for col in } A \text{ and } B

$a_{\text{rowcol}} \cdot b_{\text{rowcol}} = c_{\text{rowcol}}$

\text{for } i \text{ in } A \text{ and } B

\text{for } j \text{ in } A \text{ and } B

$a_{ij} \cdot b_{ij} = c_{ij}$

Let’s focus on that last line above.

a_{ij} \cdot b_{ij} = c_{ij}

This line represents elementwise multiplication. For each row i in A and B, we iterate through each column j in those rows, and take their product.

In einsum notation, we can more succinctly write this as ij, ij \rightarrow ij. This has 4 parts.

  1. ij refers to the iterators working on the rows and columns of Ai works on the rows and j works on the columns.
  2. ij refers to the exact same iterators working on B.
  3. \rightarrow tells us an output will be returned.
  4. ij refers to the exact same iterators that will be responsble for making up the output matrix C. The location ij says where the product of two elements will be located in C.

Matrix Multiplication

Let’s cover matrix multiplication in the same manner as above.

A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} , B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}
\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \cdot \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} (1 \cdot 5) + (2 \cdot 7) & (1 \cdot 6) + (2 \cdot 8) \\ (3 \cdot 5) + (4 \cdot 7) & (3 \cdot 6) + (4 \cdot 8) \end{bmatrix} = C

An alternative way to look at it…

Warning

Don’t dwell too much on this; it may help to refer back to this later to help understand the einsum notation below.

\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \cdot \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} (a_{00}b_{00} + a_{01}b_{10}) & (a_{00}b_{01} + a_{01}b_{11}) \\ (a_{10}b_{00} + a_{11}b_{10}) & (a_{10}b_{01} + a_{11}b_{11}) \end{bmatrix} = C

Matrix multiplication simply involves taking the dot product of each row in the first matrix with each column in the second matrix.

We’ll need to use 3 iterators for this: one iterator i to loop through the rows of A, another iterator j to loop through the columns of B, and a third iterator k to loop through the elements in a row and column.

\text{for row in } A

\text{for col in } B

$\text{for ele in } A_{\text{row}} \text{ and } B_{\text{col}}$

   $a_{\text{rowele}} \cdot b_{\text{elecol}} \mathrel{+}= c_{\text{rowcol}}$

\text{for } i \text{ in } A

\text{for } j \text{ in } B

 $\text{for } k \text{ in } A_{i} \text{ and } B_{j}$

   $a_{ik} \cdot b_{kj} \mathrel{+}= c_{ij}$

Let’s focus in on the last line above.

a_{ik} \cdot b_{kj} \mathrel{+}= c_{ij}

This can more succinctly be written in einsum notation as ik, kj \rightarrow ij — for each row i in A, and for each column j in B, iterate through each element k, take their product, and sum the those products. The location of the output of the dot product in the output matrix C is c_{ij}.

Various Examples

1D Operations

A = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, B = \begin{bmatrix} 4 \\ 5 \\ 6 \end{bmatrix}

Returning a View of A

Pseudocode

For each row i, output the row.

Einsum Notation

i \rightarrow i


Summing the Values of A

Pseudocode

Iterate through each row i, and sum all rows.

Einsum Notation

i \rightarrow \phantom{i}

Note

A scalar is output, hence no output iterator.


Hadamard Product of A and B

Pseudocode

For each row i in A and B, multiply them together.

Einsum Notation

i, i \rightarrow i


Dot Product of A and B

Pseudocode

For each row i in A and B, multiply them together, and sum the products.

Einsum Notation

i, i \rightarrow \phantom{i}

Note

A scalar is output, hence no output iterator.


Outer Product of A and B

Pseudocode

For each row i in A, multiply it with each row j in B.

Einsum Notation

i, j \rightarrow ij

Expanded example

The outer product involves multiplying each element in A with all elements in B.

\text{for row in } A

\text{for another-row in } B

 $a_{\text{row}} \cdot b_{\text{another-row}} = c_{\text{rowanother-row}}$

\text{for } i \text{ in } A

\text{for } j \text{ in } B

 $a_{i} \cdot b_{j} = c_{ij}$

2D Operations

A = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} , B = \begin{bmatrix} 9 & 8 & 7 \\ 6 & 5 & 4 \\ 3 & 2 & 1 \end{bmatrix}

Return a View of A

Pseudocode

For each row i, iterate through each column j and output it.

Einsum Notation

ij \rightarrow ij


Transpose A

Pseudocode

For each row i, iterate through each column j and output it in C at row j and column i.

Einsum Notation

ij \rightarrow ji


Return the Main Diagonal of A

Pseudocode

For each row i, iterate through each column i and output it.

Einsum Notation

ii \rightarrow i


Obtain the Trace of A

Pseudocode

For each row i, iterate through each column i and sum them.

Einsum Notation

ii \rightarrow \phantom{i}

Note

A scalar is output, hence no output iterator.


Sum the Rows of A

Pseudocode

For each row i, iterate through each column j and sum them.

Einsum Notation

ij \rightarrow j


Sum the Columns of A

Pseudocode

For each column j, iterate through each row i and sum them.

Einsum Notation

ij \rightarrow i


Hadamard Product of A and B

Pseudocode

For each row i in A and B, iterate throuch each column j, and take their product.

Einsum Notation

ij, ij \rightarrow ij


Hadamard Product of A and B Transposed (A \odot B^{T})

Pseudocode

For each row i in A, and for each row j in B, iterate through each column j in A and each column i in B, and take their product.

Einsum Notation

ij, ji \rightarrow ij


Matrix Product of A and B

Pseudocode

For each row i in A, and for each column j in B, iterate through each element k, take their product, and then sum those products.

Einsum Notation

ik, kj \rightarrow ij


Each Row of A Multiplied with B

Pseudocode

For each row i in A, and for each row j in B, iterate through each column k and take their product.

Einsum Notation

ik, jk \rightarrow ijk

Note

A three dimensional tensor is output, hence the three output iterators.


Every Element of A Multiplied with B

Pseudocode

For each row i in A, iterate through each column j and multiply it with each row k in B by iterating through each column l in that row k.

Einsum Notation

ij, kl \rightarrow ijkl

Note

A four dimensional tensor is output, hence the four output iterators.

Conclusion

And that’s that! The key is to think in terms of iterators that return locations in a matrix.

It may help to implement the operations above by yourself through pencil and paper, and in a programming languge too.

If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do let me know!


  1. This assumes the matrix is zero indexed. This means \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} is the zeroth row of A. ↩︎

1 Like