pytorch中的 scatter_()函数使用和详解

news/2024/7/7 13:27:47

scatter(dim, index, src)的三个参数为:

(1)dim:沿着哪个维度进行索引

(2)index: 用来scatter的元素索引

(3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量

官方给的例子为三维情况下的例子:

y = y.scatter(dim,index,src)

#则结果为:
y[ index[i][j][k]  ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ]  = src[i][j][k] # if dim == 2

如果是二维的例子,则应该对应下面的情况:

y = y.scatter(dim,index,src)

#则:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ]  = src[i][j] #if dim==1 

我们举一个实际的例子:

import torch

x = torch.randn(2,4)
print(x)
y = torch.zeros(3,4)
y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)
print(y)


#结果为:
tensor([[-0.9669, -0.4518,  1.7987,  0.1546],
        [-0.1122, -0.7998,  0.6075,  1.0192]])
tensor([[-0.1122,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.4518,  0.6075,  1.0192],
        [-0.9669, -0.7998,  1.7987,  0.1546]])


'''
scatter后:
y[ index[0][0] ] [0] = src[0][0] -> y[2][0]=-0.9669

y[ index[1][3] ] [3] = src[1][3] -> y[1][3]=1.10192

'''

#如果src为标量,则代表着将对应位置的数值改为src这个标量

那么这个函数有什么作用呢?其实可以利用这个功能将pytorch 中mini batch中的返回的label(特指[ 1,0,4,9 ],即size为[4]这样的label)转为one-hot类型的label,举例子如下:

import torch

mini_batch = 4
out_planes = 6
out_put = torch.rand(mini_batch, out_planes)
softmax = torch.nn.Softmax(dim=1)
out_put = softmax(out_put)

print(out_put)
label = torch.tensor([1,3,3,5])
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
print(one_hot_label)

上述的这个例子假设是一个分类问题,我设置out_planes=6,是假设总共有6类,mini_batch是我们送入的网络的每个mini_batch的样本数量,这里我们不设置网络,直接假设网络的输出为一个随机的张量 ,通常我们要对这个输出进行softmax归一化,此时就代表着其属于每个类别的概率了。说到这里都不是重点,就是为了方便理解如何使用scatter,将size为[mini_batch]的张量,转为size为[mini_batch, out_palnes]的张量,并且这个生成的张量的每个行向量都是one-hot类型的了。通过看下面的输出结果就完全能够理解了。

tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
        [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
        [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
        [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
tensor([1, 3, 3, 5])
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])


http://www.niftyadmin.cn/n/3657782.html

相关文章

python模拟简单的扑克牌游戏

这个代码实现的是J来家游戏,规则是这样的: 两个玩家随机平分一副扑克牌中的纸牌,然后从最上面出牌,名牌摆出,如果玩家出的牌是J,则将已经落地的名牌全部收归自己,放到自己牌的最下方&#xff0…

windows10 原版 纯净版 下载

最近在整理自己的电脑,想下载一个纯净的64位windows10系统,我就在网上搜啊,下载一个安装以后,就发现一大堆的软件,360金山毒霸什么玩意的,根本是删除都删除不掉,太可恶了,没办法&…

高并发高负载性能和解决方案资源索引

高并发高负载相关资源站点集合从LiveJournal后台发展看大规模网站性能优化方法 LiveJournal文档站点 Flickr 的开发者的 Web 应用优化技巧 Digg PHPs Scalability and Performance 一些重要的计数器 高并发高负载站点框架方案从LiveJournal后台发展看大规模网站性能优化方法 Li…

特征图先进行全局平均池化(GAP)之后进行全连接层,而不是直接在特征图上进行全连接的原因

这样做是因为经过cnn提取得到的特征图,其实包含了原始图片的空间信息(也就是位置信息),如果直接做了由特征图到特征向量的转换,会破坏空间信息,而先做了全局平均池化,再做全连接层,效…

.Net B/S结构程序资源索引

.net b/s 结构程序性能优化索引如何最大限度提高.NET的性能 如何最大限度提高.NET的性能(续) 客户端调用服务器端方法——ASP.NET AJAX(Atlas)、Anthem.NET和Ajax.NET Professional实现之小小比较 net b/s 数据库Microsoft SqlServer 2005 优化索引SQL Server性能调优入门(图文…

Unsupervised Person Re-identification by Soft Multilabel Learning 源码

2019CVPR REID oral文章 Unsupervised Person Re-identification by Soft Multilabel Learning 软多标签的无监督行人重试别源码在这了: https://github.com/t20134297/MAR

接口设计定理

接口设计定理相关文章链接: 模块分解原理探索模块分解原理与三权分立接口关系稳定原理探索前面几篇文章中讲过模块分解原理和接口关系稳定原理,这篇文章中将使用模块分解原理和接口关系稳定原理来推导一个重要的定理:接口设计定理。在讲解接口…

Pyramidal Person Re-IDentification via Multi-Loss Dynamic Training 复现代码

Pyramidal Person Re-IDentification via Multi-Loss Dynamic Training是 CVPR2019 REID相关论文,有人对其进行了复现,代码在这里了 https://github.com/t20134297/Pyramidal_Person_ReID