老饼讲解-深度学习 机器学习 神经网络 深度学习
pytorch入门
1.pytorch学前准备
2.pytorch的基础操作
3.pytorch的梯度计算
4.pytorch的数据
5.pytorch-小试牛刀

【例子】pytorch-tensor的条件索引

作者 : 老饼 日期 : 2023-07-28 10:50:36 更新 : 2024-01-19 08:13:32
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



在pytorch的应用中,往往需要找出tensor中符合条件的元素进行修改

本文讲解如何在找出tensor中符合条件的元素的索引,

并展示怎么利用索引对tensor的数据进一步进行修改





一、pytorch的tensor查找符合条件元素的索引   



本节讲解pytorch找出tensor中符合条件的元素的索引的两种方法



1.查找tensor中符合条件的元素的索引-where方法


pytorch可以使用where函数找出tensor中符合条件的元素的索引

示例如下:

import torch
torch.manual_seed(99)  
t = torch.arange(0, 9).view(3, 3)        # 生成一个3*3的tensor
idx    = torch.where(t>4)                # 找出tensor中符合>4的元素的索引
print('t:',t)                            # 打印tensor
print('符合条件>4的索引idx:',idx)           # 打印索引

运行结果如下:

t: tensor([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
符合条件>4的索引idx: (tensor([1, 2, 2, 2]), tensor([2, 0, 1, 2]))

✍️解说:

上述的where函数返回的索引分为多个tensor, 

 idx[0]=[1,2,2,2]代表符合条件的元素的第0维索引,

 idx[1]=[2,0,1,2]代表符合条件的元素的第1维索引,

即符合条件>4的元素索引为:[1,2],[2,0],[2,1],[2,2]



2.查找tensor中符合条件的元素的索引-argwhere方法


pytorch也可以使用argwhere函数找出tensor中符合条件的元素的索引

import torch
torch.manual_seed(99)  
t = torch.arange(0, 9).view(3, 3)        # 生成一个3*3的tensor
idx    = torch.argwhere(t>4)                # 找出tensor中符合>4的元素的索引
print('t:',t)                            # 打印tensor
print('符合条件>4的索引idx:',idx)           # 打印索引

运行结果如下:

t: tensor([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
符合条件>4的索引idx: 
tensor([[1, 2],
       [2, 0],
       [2, 1],
       [2, 2]])





二. pytorch的tensor如何根据条件修改数据    



本节讲解pytorch如何修改tensor中符合条件的元素的数据

 


1.根据条件修改tensor的值


往往需要修改矩阵中符合某个条件的数据,

那么在pytorch的tensor中可以借助where函数来实现这一效果

下面展示一个常用的条件索引的例子,用于学习和借鉴

如下述例子,将tensor中大于4的元素全部改成999:

import torch
torch.manual_seed(99)  
t = torch.arange(0, 9).view(3, 3)        # 生成一个3*3的tensor
print('修改前的t:',t)                     # 打印修改前的tensor
idx    = torch.where(t>4)                # 找出tensor中符合>4的元素的索引
print('符合条件的索引idx:',idx)           # 打印索引
t[idx] = 999                             # 根据索引修改tensor的值                    
print('修改后的t:',t)                    # 打印修改后的tensor

运行结果如下:

修改前的t: tensor([[0, 1, 2],
                  [3, 4, 5],
                  [6, 7, 8]])
符合条件的索引idx: (tensor([1, 2, 2, 2]), tensor([2, 0, 1, 2]))
修改后的t: tensor([[  0,   1,   2],
        [  3,   4, 999],
        [999, 999, 999]])



2.根据条件填充tensor的值


往往只是需要简单地根据条件填充tensor的值,

那么可以在where中加入填充条件,具体如下

示例如下:

import torch
torch.manual_seed(99)  
t = torch.arange(0, 9).view(3, 3)     # 生成一个3*3的tensor
x = torch.where(t>4,1,0)              # t>4是条件,成立则填充为1,否则为0
print(t)
print(x)

运行结果如下

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[0, 0, 0],
        [0, 0, 1],
        [1, 1, 1]])







好了,以上就是pytorch中tensor的条件索引以及根据条件修改tensor的方法了~







 End 





联系老饼