libtorch 常用api函数示例
// Example 1: turn a network's per-class scores into an 8-bit label mask and
// display it with OpenCV.
// NOTE(review): assumes the class axis of output_1 is dim 2 and that the
// argmax result is a T_height x T_width tensor (e.g. output_1 shaped
// (T_height, T_width, num_classes)) -- confirm against the model.
// torch::argmax yields an int64 (kLong) tensor. Handing that 8-byte-per-element
// buffer straight to a CV_8UC1 Mat reinterprets the raw bytes and produces a
// garbled mask, so convert to uint8 (assumes < 256 classes) and make the
// storage contiguous before wrapping it.
torch::Tensor b = torch::argmax(output_1, 2).cpu().to(torch::kU8).contiguous();
// std::cout << b << '\n';  // alternative: print the full tensor contents
b.print();  // prints a one-line summary (type and sizes), not the values
// Guard against reading past the tensor's buffer if the sizes disagree.
TORCH_CHECK(b.dim() == 2 && b.size(0) == T_height && b.size(1) == T_width,
            "argmax result does not match T_height x T_width");
// This Mat is a view over b's memory (no copy): b must outlive mask, or use
// mask.clone() / b.clone() if the mask has to be kept longer.
// data_ptr<uint8_t>() also verifies at runtime that b really is uint8.
cv::Mat mask(T_height, T_width, CV_8UC1, b.data_ptr<uint8_t>());
// mask * 255 saturates: class 0 stays black, every other class shows white.
cv::imshow("mask", mask * 255);
cv::waitKey(0);

// Example 2: masked_fill_ overwrites, in place, every element selected by a
// boolean mask.
torch::Tensor a = torch::rand({2, 3});  // uniform samples in [0, 1)
torch::Tensor aa = a.clone();           // independent copy, so `a` is untouched
aa.masked_fill_(aa > 0.5, -2);          // entries greater than 0.5 become -2
std::cout << a << '\n';
std::cout << aa << '\n';
// Sample output (values are random; this was produced by an older libtorch
// that labels tensors "Variable" -- newer versions print the tensor type,
// e.g. "[ CPUFloatType{2,3} ]"):
//  0.8803  0.2387  0.8577
//  0.8166  0.0730  0.4682
// [ Variable{2,3} ]
// -2.0000  0.2387 -2.0000
// -2.0000  0.0730  0.4682
// [ Variable{2,3} ]
参考:
【1】libtorch 常用api函数示例 https://blog.csdn.net/yang332233/article/details/106199180
页:
[1]