Linux 拨号vps windows公众号手机端

pytorch中怎么制作自己的数据集

lewis 5年前 (2020-02-05) 阅读数 12 #大数据
文章标签 pytorch

要在PyTorch中制作自己的数据集,你需要创建一个继承自torch.utils.data.Dataset的自定义数据集类。这个类需要实现__len____getitem__方法。

下面是一个简单的例子,展示了如何创建一个自定义数据集类:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y

在这个例子中,CustomDataset类接受两个参数datatargets,分别代表数据和对应的标签。__len__方法返回数据集的长度,__getitem__方法根据给定的索引返回对应的数据和标签。

接下来,你可以实例化这个自定义数据集类并将其用于创建一个DataLoader对象,从而可以方便地迭代数据集进行训练或测试:

data = [...]  # your data
targets = [...]  # your targets

custom_dataset = CustomDataset(data, targets)
dataloader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)

现在你可以使用dataloader来迭代自定义数据集进行训练。

版权声明

本文仅代表作者观点,不代表米安网络立场。

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

热门