PyTorch各种维度变换函数总结
介绍
本文对于PyTorch中的各种维度变换的函数进行总结,包括reshape()
、view()
、resize_()
、transpose()
、permute()
、squeeze()
、unsqeeze()
、expand()
、repeat()
函数的介绍和对比。
contiguous
区分各个维度转换函数的前提是需要了解contiguous。在PyTorch中,contiguous指的是Tensor底层一维数组的存储顺序和其元素顺序一致。
Tensor是以一维数组的形式存储的,C/C++使用行优先(按行展开)的方式,Python中的Tensor底层实现使用的是C,因此PyThon中的Tensor也是按行展开存储的,如果其存储顺序和按行优先展开的一维数组元素顺序一致,就说这个Tensor是连续(contiguous)的。
形式化定义:
对于任意的维张量,如果满足对于所有的,第维相邻元素间隔=第维相邻元素间隔第维长度的乘积,则是连续的:
- 表示第维相邻元素之间间隔的位数,称为步长,可通过
stride()
方法获得。 - 表示固定其他维度时,第维的元素数量,即第维的长度,通过
size()
方法获得。
Python中的多维张量按照行优先展开的方式存储,访问矩阵中下一个元素是通过偏移来实现的,这个偏移量称为步长(stride),比如python中,访问矩阵的同一行中的相邻元素,物理结构需要偏移1个位置,即步长为1,同一列中的两个相邻元素则步长为3。
举例说明:
1 |
|
PyTorch中的有一些操作没有真正地改变tensor的内容,只是改变了索引和元素的对应关系,操作之前和操作之后的数据是共享的。这些操作包括narrow(),view(),expand(),transpose()
等[2]。当执行这些函数时,原来语义上相邻、内存里也相邻的元素,可能会出现语义上相邻,但内存上不相邻的情况,就不连续(not contiguous)了。
PyTorch提供了两个关于contiguous的方法:
is_contiguous()
: 判断Tensor是否是连续的contiguous()
: 返回新的Tensor,重新开辟一块内存,并且是连续的
举例说明(参考[1]):
1 |
|
可以看到,t和t2共享内存中的数据。如果对t2使用contiguous()
方法,会开辟新的内存空间:
1 |
|
关于contiguous的更深入的解释可以参考[1].
view()/reshape()
view()
tensor.view()函数返回一个和tensor共享底层数据,但不同形状的tensor。使用view()
函数的要求是tensor必须是contiguous的。
用法如下:
1 |
|
reshape()
tensor.reshape()类似于tensor.contigous().view()
操作,如果tensor是连续的,则reshape()操作和view()相同,返回指定形状、共享底层数据的tensor;如果tensor是不连续的,则会开辟新的内存空间,返回指定形状的tensor,底层数据和原来的tensor是独立的,相当于先执行contigous()
,再执行view()
。
如果不在意底层数据是否使用新的内存,建议使用
reshape()
代替view()
.
resize_()
tensor.resize_()函数,返回指定形状的tensor,与reshape()
和view()
不同的是,resize_()
可以只截取tensor一部分数据,或者是元素个数大于原tensor也可以,会自动扩展新的位置。
resize_()
函数对于tensor的连续性无要求,且返回的值是共享的底层数据(同view()
),也就是说只返回了指定形状的索引,底层数据不变的。
transpose()/permute()
permute()
和transpose()
还有t()
是PyTorch中的转置函数,其中t()
函数只适用于2维矩阵的转置,是这三个函数里面最”弱”的。
transpose()
tensor.transpose(),返回tensor的指定维度的转置,底层数据共享,与view()/reshape()
不同的是,transpose()
只能实现维度上的转置,不能任意改变维度大小。
对于维度交换来说,view()/reshape()
和transpose()
有很大的区别,一定不要混用!混用了以后虽然不会报错,但是数据是乱的,血坑。
reshape()/view()
和transpose()
的区别在于对于维度改变的方式不同,前者是在存储顺序的基础上对维度进行划分,也就是说将存储的一维数组根据shape大小重新划分,而transpose()
则是真正意义上的转置,比如二维矩阵的转置。
举个例子:
1 |
|
permute()
tensor.permute()函数,以view的形式返回矩阵指定维度的转置,和transpose()
功能相同。
与transpose()
不同的是,permute()
同时对多个维度进行转置,且参数是期望的维度的顺序,而transpose()
只能同时对两个维度转置,即参数只能是两个,这两个参数没有顺序,只代表了哪两个维度进行转置。
举个例子:
1 |
|
squeeze()/unsqueeze()
squeeze()
tensor.squeeze()返回去除size为1的维度的tensor,默认去除所有size=1的维度,也可以指定去除某一个size=1的维度,并返回去除后的结果。
举个例子:
1 |
|
unsqueeze()
tensor.unsqueeze()与squeeze()
相反,是在tensor插入新的维度,插入的维度size=1,用于维度扩展。
举个例子:
1 |
|
expand()/repeat()
expand()
tensor.expand()的功能是扩展tensor中的size为1的维度,且只能扩展size=1的维度。以view的形式返回tensor,即不改变原来的tensor,只是以视图的形式返回数据。
举个例子:
1 |
|
repeat()
tensor.repeat()用于维度复制,可以将size为任意大小的维度复制为n倍,和expand()
不同的是,repeat()
会分配新的存储空间,是真正的复制数据。
举个例子:
1 |
|
如果维度size=1的时候,
repeat()
和expand()
的作用是一样的,但是expand()
不会分配新的内存,所以优先使用expand()
函数。
总结
view()/reshape()
两个函数用于将tensor变换为任意形状,本质是将所有的元素重新分配。t()/transpose()/permute()
用于维度的转置,转置和reshape()
操作是有区别的,注意区分。squeeze()/unsqueeze()
用于压缩/扩展维度,仅在维度的个数上去除/添加,且去除/添加的维度size=1。expand()/repeat()
用于数据的复制,对一个或多个维度上的数据进行复制。- 以上提到的函数仅有两种会分配新的内存空间:
reshape()
操作处理非连续的tensor时,返回tensor的copy数据会分配新的内存;repeat()
操作会分配新的内存空间。其余的操作都是返回的视图,底层数据是共享的,仅在索引上重新分配。
Reference
2. stackoverflow-pytorch-contiguous
3. PyTorch官方文档
- 本文作者:Kangshitao
- 本文链接:http://kangshitao.github.io/2020/11/21/pytorch-function/index.html
- 版权声明:本博客所有文章均采用 BY-NC-SA 许可协议,转载请注明出处!