【问题标题】:How to get the dataset from prepare_data() to setup() in PyTorch Lightning / 如何在 PyTorch Lightning 中将数据集从 prepare_data() 传递到 setup()
【发布时间】:2021-07-30 03:05:39
【问题描述】:

我使用 PyTorch Lightning 的 DataModule，在 prepare_data() 方法中用 NumPy 创建了自己的数据集。现在，我想把这些数据传递给 setup() 方法，并将其拆分为训练集和验证集。

import numpy as np 
import pytorch_lightning as pl 
from torch.utils.data import random_split, DataLoader, TensorDataset
import torch
from torch.autograd import Variable
from torchvision import transforms

# Seed NumPy's global RNG so the synthetic data below is reproducible.
np.random.seed(42)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class DataModuleClass(pl.LightningDataModule):
    """Question's LightningDataModule for a synthetic two-feature dataset.

    As posted, this code is intentionally unfinished and does not run.
    ``setup`` leaves the right-hand side of ``data =`` empty (a syntax
    error). The dataloader methods call bare ``setup()`` / ``prepare_data()``
    and read attributes that are never assigned anywhere in the class.
    """
    def __init__(self):
        super().__init__()
        # Standard deviation of the noise column drawn in prepare_data.
        self.constant = 2
        # Stored here, but not passed to any DataLoader below.
        self.batch_size = 10
        
    def prepare_data(self):
        """Generate the data and return it as a ``TensorDataset``.

        ``a`` holds 500 draws from a uniform distribution on [0, 500).
        ``b`` holds as many draws from a zero-mean normal with standard
        deviation ``self.constant``. The target ``c`` is ``a + b``.
        """
        a = np.random.uniform(0, 500, 500)
        b = np.random.normal(0, self.constant, len(a))

        c = a + b
        # (2, 500) array transposed to (500, 2): one row per sample.
        X = np.transpose(np.array([a, b]))
        
        # Convert the NumPy arrays to float32 tensors and move them to `device`.
        self.x_train_tensor = torch.from_numpy(X).float().to(device)
        self.y_train_tensor = torch.from_numpy(c).float().to(device)
        
        training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)

        return training_dataset
    
    def setup(self):
        """Split the dataset into training and validation subsets (unfinished)."""
        # NOTE(review): Lightning calls setup with a ``stage`` argument, so this
        # signature would fail when driven by a Trainer; confirm against the
        # installed Lightning version.
        # The split sizes [400, 100] must sum to the dataset length (500 here).
        data = # What I have to write to get the data from prepare_data()
        self.train_data, self.val_data = random_split(data, [400, 100])
        
        
    def train_dataloader(self):
        """Intended to return a DataLoader over the training subset (unfinished)."""
        # `setup` is a method, so the bare name `setup()` is not defined here.
        # `self.training_dataloader` is never assigned anywhere in the class.
        training_dataloader = setup() # Need to get the training data
        return DataLoader(self.training_dataloader)

    def val_dataloader(self):
        """Intended to return a DataLoader over the validation subset (unfinished)."""
        # `prepare_data` is a method, so the bare name `prepare_data()` is
        # not defined here. `self.validation_dataloader` is never assigned.
        validation_dataloader = prepare_data() # Need to get the validation data
        return DataLoader(self.validation_dataloader)
    
# Build the module and print the TensorDataset that prepare_data returns.
obj = DataModuleClass()
print(obj.prepare_data())  

【问题讨论】:

    标签: pytorch pytorch-lightning pytorch-dataloader


    【解决方案1】:

    答案与你之前那个问题相同：在 prepare_data() 中把数据集保存为实例属性（self.training_dataset），然后在 setup() 中读取它并进行拆分。

    def prepare_data(self):
        """Build a synthetic regression dataset and keep it on the module.

        Feature ``a`` is 500 draws from a uniform distribution on [0, 500).
        Feature ``b`` is as many draws from a zero-mean normal with standard
        deviation ``self.constant``. The target is ``a + b``.

        Features and target are turned into float32 tensors on ``device``.
        They are wrapped in a ``TensorDataset`` stored as
        ``self.training_dataset``, which ``setup`` reads. Returns ``None``.
        """
        # NOTE(review): the Lightning docs say prepare_data may run on only one
        # process, so state assigned here can be missing on other ranks under
        # distributed training. Confirm before running with DDP.
        uniform_feature = np.random.uniform(0, 500, 500)
        noise_feature = np.random.normal(0, self.constant, len(uniform_feature))
        target = uniform_feature + noise_feature

        # Stack the two features column-wise, giving one row per sample.
        features = np.array([uniform_feature, noise_feature]).T

        self.x_train_tensor = torch.from_numpy(features).float().to(device)
        self.y_train_tensor = torch.from_numpy(target).float().to(device)
        self.training_dataset = TensorDataset(
            self.x_train_tensor, self.y_train_tensor
        )
    
    def setup(self):
        data = self.training_dataset
        self.train_data, self.val_data = random_split(data, [400, 100])
        
        
    def train_dataloader(self):
        """Return a DataLoader over the training subset created by ``setup``.

        Uses ``self.batch_size`` when the module defines one (the question's
        ``__init__`` sets it to 10). Otherwise it falls back to 1, the
        DataLoader default the original always used.
        """
        # Honour the configured batch size instead of silently using batches of 1.
        batch_size = getattr(self, "batch_size", 1)
        return DataLoader(self.train_data, batch_size=batch_size)
    
    def val_dataloader(self):
        """Return a DataLoader over the validation subset created by ``setup``.

        Uses ``self.batch_size`` when the module defines one. Otherwise it
        falls back to 1, the DataLoader default the original always used.
        """
        # Keep validation batching consistent with the configured batch size.
        batch_size = getattr(self, "batch_size", 1)
        return DataLoader(self.val_data, batch_size=batch_size)
    

    【讨论】:

      猜你喜欢
      • 2021-07-29
      • 2020-09-12
      • 2023-01-05
      • 2019-12-14
      • 1970-01-01
      • 2018-05-06
      • 2023-02-15
      • 2020-04-14
      • 2021-03-03
      相关资源
      最近更新 更多