Linux 拨号vps windows公众号手机端

PyTorch中怎么实现自定义数据集类

lewis 1年前 (2024-04-09) 阅读数 13 #大数据
文章标签 pytorch

要实现自定义数据集类,需要继承PyTorch中的Dataset类,并重写其中的两个方法:len__和__getitem。下面是一个简单的例子,演示如何实现一个自定义数据集类:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        data_point = self.data[index]
        target = self.targets[index]
        
        return data_point, target

在上面的例子中,CustomDataset类接收两个参数data和targets作为初始化参数,分别表示数据和标签。然后重写了__len__方法,返回数据集的长度,重写了__getitem__方法,根据索引index返回对应的数据点和标签。

使用这个自定义数据集类的方法如下:

data = [...] # your data
targets = [...] # your targets

custom_dataset = CustomDataset(data, targets)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)

for data, target in data_loader:
    # do something with data and target

这样就可以通过自定义数据集类来加载自己的数据集,并使用DataLoader来批量加载数据。

版权声明

本文仅代表作者观点,不代表米安网络立场。

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

热门