Tensor 默认是支持[]
操作符的,因此可以使用这样的方式来获取元素:
auto foo = torch::randn({1, 2, 3, 4});
float value = foo[0][1][2][2];
另一种方式是用Tensor对象的index
函数,它的优势是支持slice。 对于单个元素,可以类似Pytorch中,直接用index({i, j, k})
的方式来索引:
auto foo = torch::randn({1, 2, 3, 4});
float value = foo.index({0, 1, 2, 2});
那么python中很常用的slice呢?例如foo[..., :2, 1:, :-1]
,该怎么在Libtorch中表示? 这里需要用到torch::indexing::Slice
对象,来实现Python中的Slice,看看下面的例子你就明白了:
using namespace torch::indexing;
auto foo = torch::randn({1, 2, 3, 4});
// 等效于Python中的foo[:, 0:1, 2:, :-1]
auto bar = foo.index({Slice(), Slice(0, 1), Slice(2, None), Slice(None, -1)});
应该是能满足Python中slice同样的使用场景。
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
添加我为好友,拉您入交流群!
请使用微信扫一扫!