Ein-what?
In a nutshell, einops is a library that helps us get rid of those weird tensor operations that didn't make sense until we printed the shapes after every line or added multiple comments describing what's going on. Instead, einops makes our tensor operations much more intuitive and readable.
x = x.reshape(x.size(0), -1) ## [B,C,H,W] becomes [B, C*H*W]
becomes
rearrange(x, 'b c h w -> b (c h w)')
Which is pretty neat if you ask me :)
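If you want to convince yourself the two are identical, here's a minimal sketch (the shape is just an example I picked):

import torch
from einops import rearrange

x = torch.randn(8, 3, 32, 32)  # [B, C, H, W]
flat_reshape = x.reshape(x.size(0), -1)             # [B, C*H*W]
flat_einops = rearrange(x, 'b c h w -> b (c h w)')  # same result, but readable
assert torch.equal(flat_reshape, flat_einops)
print(flat_einops.shape)  # torch.Size([8, 3072])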
Some useful ops
Repeat
Repeats a tensor along a given (or new) dim.
from einops import repeat
output_tensor = repeat(input_tensor, 'h w -> h w c', c=3)
You can set c to any value you like.
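For example, here's a sketch that turns a grayscale image into a 3-channel one, and also repeats along an existing dim (the shapes are just for illustration):

import torch
from einops import repeat

gray = torch.rand(100, 100)                 # a single-channel 'image', [H, W]
rgb = repeat(gray, 'h w -> h w c', c=3)     # new channel axis: [100, 100, 3]
tall = repeat(gray, 'h w -> (r h) w', r=2)  # repeat along existing height: [200, 100]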
Rearranging axes
Goodbye np.moveaxis
or x.permute/x.transpose
from einops import rearrange
x = rearrange(x, 'c h w -> h w c')
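This is equivalent to the permute call it replaces, just self-documenting (shape here is arbitrary):

import torch
from einops import rearrange

x = torch.randn(3, 100, 100)        # [C, H, W]
y = rearrange(x, 'c h w -> h w c')  # [H, W, C]
assert torch.equal(y, x.permute(1, 2, 0))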
In PyTorch:
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange

model = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    MaxPool2d(kernel_size=2),
    # flattening
    Rearrange('b c h w -> b (c h w)'),  ## this is where the magic is
    Linear(16*5*5, 120),
    ReLU(),
    Linear(120, 10),
)
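A quick sanity check of the forward pass; the 16*5*5 flatten size implies 32x32 inputs (LeNet-style), which is an assumption on my part:

import torch

batch = torch.randn(4, 3, 32, 32)  # assumed 32x32 RGB inputs
logits = model(batch)
print(logits.shape)  # torch.Size([4, 10])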
Going wild with reshapes
The best part about einops is how intuitive it is to reshape tensors without worrying about getting the dims wrong. For example, this is how you can concatenate all the images in a batch along the height axis:
import torch
import matplotlib.pyplot as plt
from einops.layers.torch import Rearrange

x = torch.randn(5, 3, 100, 100)            # [B, C, H, W]
layer = Rearrange('b c h w -> (b h) w c')  # the layer takes only the pattern, not the tensor
x = layer(x).numpy()                       # [500, 100, 3]
plt.imshow(x)
plt.show()
Now, to do the same thing along the width, all we have to do is:
layer = Rearrange('b c h w -> h (b w) c')
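And if you're not inside a Sequential, you don't need the layer at all; the functional rearrange does the same thing:

import torch
from einops import rearrange

wide = rearrange(torch.randn(5, 3, 100, 100), 'b c h w -> h (b w) c')
print(wide.shape)  # torch.Size([100, 500, 3])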
I’ll add more stuff to this post after further explorations :)