# PyTorch: crop a per-sample window from a batch of feature maps with
# F.grid_sample (original post by 旧城等待, 2021-06-24).
#
# Each sample is cropped around its OWN location -- the spatial position where
# its channel-averaged feature map is largest -- so every sample in the batch
# gets a different BOX x BOX window, all in one batched grid_sample call.
#
# Coordinates are normalized with the align_corners=True convention
# (pixel 0 -> -1, pixel HW-1 -> +1), so grid_sample must be called with
# align_corners=True; since PyTorch 1.3.0 the default is False, which would
# shift every sample by half a pixel and blend neighbouring values.
import torch

# Example batch: N samples, 3 channels, height = width = HW.
HW = 7
N = 2
x = torch.rand(N, 3, HW, HW)

# Channel mean -> (N, HW, HW); flatten spatially and take the argmax per sample.
# points holds flat indices in [0, HW*HW), e.g. tensor([14, 4]).
points = torch.argmax(torch.mean(x, dim=1).view(N, HW * HW), dim=1)

# Flat index -> (row, column). In recent PyTorch `points / HW` is TRUE
# division and would give fractional rows (4/7 = 0.571...), so use explicit
# floor division to keep integer indices.
# Naming caveat kept from the original: x_p is the ROW (vertical axis) and
# y_p is the COLUMN (horizontal axis).
x_p = torch.div(points, HW, rounding_mode='floor')  # e.g. tensor([2, 0])
print(x_p)
y_p = torch.fmod(points, HW)                        # e.g. tensor([0, 4])
print(y_p)

# grid_sample reads the last grid dimension as (x, y) = (column, row), hence
# the (y_p, x_p) order. Shape (N, 2) for any N, e.g. [[0., 2.], [4., 0.]].
z_p = torch.stack((y_p, x_p), dim=1).float()

# Scale pixel indices to [-1, 1]; algebraically identical to
# ((z_p + 1) - (HW + 1) / 2) / ((HW - 1) / 2).
z_p = z_p * 2 / (HW - 1) - 1
# (N, 1, 1, 2) so it broadcasts over the BOX x BOX window below,
# e.g. [[[[-1.0000, -0.3333]]], [[[0.3333, -1.0000]]]].
grid = z_p.view(N, 1, 1, 2)

# Window offsets in normalized units: one pixel is 2 / (HW - 1). The window
# is BOX x BOX with BOX = 2 * BOX_LEFT + 1 (3 x 3 here), centred on the max.
step = 2 / (HW - 1)
BOX_LEFT = 1
BOX = 2 * BOX_LEFT + 1
# direct[i, j, 0] is the column offset of window cell (i, j); its transpose
# holds the row offset. Each has shape (BOX, BOX, 1).
direct = torch.linspace(-BOX_LEFT * step, BOX_LEFT * step, BOX).unsqueeze(0).repeat(BOX, 1).unsqueeze(-1)
direct_trans = direct.transpose(1, 0)

# full[n, i, j] = (column, row) of window cell (i, j) for sample n, shape
# (N, BOX, BOX, 2). grid_sample itself accepts coordinates outside [-1, 1]
# (they are handled by padding_mode); the clamp makes windows that hang over
# the image edge repeat the edge row/column, like padding_mode='border'.
full = torch.clamp(torch.cat([direct, direct_trans], dim=2).unsqueeze(0) + grid, -1, 1)

# Crop: (N, C, BOX, BOX). Grid points land on pixel centres, so the bilinear
# result equals the original pixel values (up to float rounding).
crop = torch.nn.functional.grid_sample(x, full, align_corners=True)
print(crop)
还没有评论,来说两句吧...