163 字
1 分钟
einops库
2024-09-14

einops库基本使用

einops库主要有三个功能:rearrange、reduce和repeat

import torch
from einops import rearrange,reduce,repeat
# Random 4-D demo tensor of shape (2, 3, 4, 5); the einops patterns below
# name its axes b, i, h and w in that order.
x = torch.randn((2, 3, 4, 5))

# 1. Transpose: swap axes 1 and 2, once with torch and once with einops.
out1 = torch.transpose(x, 1, 2)
out2 = rearrange(x, 'b i h w ->b h i w')

# 2. Reshape: merge the first two axes, then split them apart again.
out1 = torch.reshape(x, (-1, 4, 5))
out2 = rearrange(x, 'b i h w->(b i) h w')
# b=2 tells einops how to factor the merged leading axis back into (b, i).
out3 = rearrange(out2, '(b i) h w -> b i h w', b=2)
# Compare the merge-then-split round trip against the original tensor.
flag = torch.allclose(out3, x)
print(flag)

# 3. Pooling-style reductions.
out1 = reduce(x, 'b i h w -> b i h', 'mean')  # average over w
out2 = reduce(x, 'b i h w -> b i h 1', 'sum')  # sum over w, keep a size-1 axis
out3 = reduce(x, 'b i h w-> b i', 'max')  # max over both h and w

# 4. Repeat operations.
# Append a trailing size-1 axis (the einops counterpart of torch.unsqueeze).
out1 = rearrange(x, 'b i h w -> b i h w 1')
# Grow that trailing axis from length 1 to length 2 (cf. torch.tile).
out2 = repeat(out1, 'b i h w 1 -> b i h w 2')
# Double the extent of both the h and w axes.
out3 = repeat(x, 'b i h w -> b i (2 h) (2 w)')