【PYG】dataloader和densedataloader

news/2024/7/8 10:34:24 标签: python, pytorch, 深度学习

DenseDataLoader 是专门用于处理稠密图数据的,而 DataLoader 通常用于处理稀疏图数据。两者的主要区别在于它们的输入数据格式和处理方式。DenseDataLoader 适合处理固定大小的邻接矩阵和节点特征矩阵的数据,而 DataLoader 更加灵活,可以处理稀疏表示的图数据。

主要区别

  • DataLoader:

    • 适合处理稀疏图数据。
    • 通常与 torch_geometric.data.Data 一起使用,其中边索引是稀疏表示的。
    • 更加灵活,适合处理各种不同形状和大小的图。
  • DenseDataLoader:

    • 适合处理稠密图数据。
    • 通常与固定大小的邻接矩阵和节点特征矩阵一起使用。
    • 更高效地处理固定大小的图数据。

使用示例

使用 DenseDataLoader

如果你有固定大小的邻接矩阵和节点特征矩阵,可以直接使用 DenseDataLoader 加载数据:

1. 导入必要的库
python">import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定义数据集类
python">class MyDenseDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        self.adj_matrix = self.create_adj_matrix(num_nodes)

    def create_adj_matrix(self, num_nodes):
        # 创建环形图的邻接矩阵
        adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        for i in range(num_nodes):
            adj_matrix[i, (i + 1) % num_nodes] = 1
            adj_matrix[(i + 1) % num_nodes, i] = 1
        return adj_matrix

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 创建随机特征和标签
        x = torch.randn((self.num_nodes, self.num_node_features))
        y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签
        return Data(x=x, adj=self.adj_matrix, y=y)
3. 创建数据集和封装数据
python"># 参数设置
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数

# 创建数据集
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
4. 使用 DenseDataLoader
python"># 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)

# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:
    print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)
    print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)
    print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)
    break  # 仅查看第一个批次的形状

解释

  1. 导入库

    • 导入 torchtorch_geometric.data 中的 Datatorch_geometric.loader 中的 DenseDataLoader
  2. 定义 MyDenseDataset

    • __init__ 方法初始化数据集参数,并创建邻接矩阵。
    • create_adj_matrix 方法创建环形图的邻接矩阵。
    • __len__ 方法返回数据集的样本数量。
    • __getitem__ 方法生成每个样本的随机节点特征和标签,并返回节点特征矩阵、邻接矩阵和标签。
  3. 创建数据集

    • 使用 MyDenseDataset 类创建一个包含 100 个样本的数据集,每个样本包含 10 个节点,每个节点有 8 个特征。
  4. 使用 DenseDataLoader

    • 使用 DenseDataLoader 加载 dataset,设置批次大小为 32,并进行随机打乱。
    • 在获取一个批次的数据时,检查 xadjy 的形状,以确保其符合期望的三维形状。

通过这个完整的示例代码,你可以生成、封装和加载稠密图数据,并确保每个批次的数据形状保持正确。这种方法适合处理节点数和边数固定的图数据,提高数据加载和处理的效率。

定义数据集类并使用 DenseDataLoader

python">import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader  # 更新导入路径

class MyDenseDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        self.adj_matrix = self.create_adj_matrix(num_nodes)

    def create_adj_matrix(self, num_nodes):
        # 创建环形图的邻接矩阵
        adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        for i in range(num_nodes):
            adj_matrix[i, (i + 1) % num_nodes] = 1
            adj_matrix[(i + 1) % num_nodes, i] = 1
        print(adj_matrix)
        return adj_matrix

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 创建随机特征和标签
        x = torch.randn((self.num_nodes, self.num_node_features))
        y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签
        return Data(x, self.adj_matrix, y=y)

# 创建数据集
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)

# 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)

# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:
    print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)
    # print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)
    print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)
    break  # 仅查看第一个批次的形状

使用 DataLoader

如果你使用的是 DataLoader,则数据应当是 torch_geometric.data.Data 对象,并将数据封装在列表中:

python">import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # 更新导入路径

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        x = torch.randn(self.num_nodes, self.num_node_features)
        edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()
        y = torch.randn(self.num_nodes, 1)
        return Data(x=x, edge_index=edge_index, y=y)

# 创建数据集
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数
dataset = MyDataset(num_samples, num_nodes, num_node_features)

# 使用 DataLoader 加载数据
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 迭代加载数据
for batch in loader:
    print("Batch node features shape:", batch.x.shape)  # 期望输出形状为 (320, 8)
    print("Batch edge index shape:", batch.edge_index.shape)

总结

  • DenseDataLoader:处理固定大小的邻接矩阵和节点特征矩阵的数据,__getitem__ 返回Data(x, adj, y)。
  • DataLoader:处理 torch_geometric.data.Data 对象,__getitem__ 返回一个 Data 对象。

确保数据格式与使用的加载器相匹配,以避免属性错误和其他兼容性问题。


http://www.niftyadmin.cn/n/5537088.html

相关文章

外挂级OCR神器:免费文档解析、表格识别、手写识别、古籍识别、PDF转Word

TextIn Tools是一款免费的在线OCR工具,支持快速准确的文字和表格识别,手写、古籍识别,提供PDF转Markdown大模型辅助工具,同时支持PDF、WORD、EXCEL、JPG、PPT等各类格式文件的转化。 TextIn Tools特点 免费:所有产品提…

什么是 qobject_cast?

前言 在 C++ 中,类型转换是一项常见的操作,比如将 int 转换为 char 或将 QString 用于 QMessageBox。但是,为什么我们需要将一个类转换为另一个类呢?本文将解释 qobject_cast 是什么,它的作用以及为什么需要类型转换。 dynamic_cast 和 qobject_cast 的概述 什么是 dyn…

如何设计一个峰值电流可以100A的PCB?

目录 01.PCB上走线 那我们要选什么样的可以通过100A呢? 02.接线柱 03.定做铜排 04.特殊工艺 通常的PCB设计电流都不会超过10 A,甚至5 A。尤其是在家用、消费级电子中,通常PCB上持续的工作电流不会超过2 A。但是最近要给公司的产品设计动…

程序员,去哪个城市工作更幸福?

深漂、沪漂、京漂、杭漂……又是一年毕业季,作为CS专业or新手程序员会选择什么城市工作呢?希望这篇文章给各位一些参考。 根据拉勾招聘大数据研究院的数据显示,超六成程序员集中在一线城市,其中北京19%,深圳16%&#x…

element ui 的 el-date-picker 日期选择组件设置可选日期范围

有时候,在使用日历控件的时候,我们需要进行定制,控制用户只能在指定日期范围内进行日期选择,在这里,我使用了 element ui 的 el-date-picker 日期选择控件,控制只能选择当前月及往前的2个月,效果…

程序员需要具备的核心竞争力

随着IT人才的饱和,互联网就业形势越严峻。 作为一名工程师,需要具备哪些基本素养与能力,才能够应对这样的就业环境? 按照优先级排序如下: 1 业务理解、需求沟通能力 业务理解与需求沟通看似是技术经理、架构师需要…

【vsCode】如何开发一个vscode插件

开发一个VSCode插件涉及多个步骤,包括项目初始化、编写代码、调试运行以及打包发布。以下是一个简化的指南,帮助你开始开发VSCode插件的旅程: PS:首先要确保您的系统上安装了Node.js(最好是v18以上版本)、npm和VS Code。最后&…

【APK】Unity出android包,报错 Gradle build failed.See the Console for details

参考大佬的博客:报错:Gradle build failed.See the Console for details.(已解决)_starting a gradle daemon, 1 incompatible daemon co-CSDN博客 本地出Android包,Build失败 解决办法: 1.下载一个低版本…