【Question title】: How to use random_split with percentage split (sum of input lengths does not equal the length of the input dataset)
【Posted】: 2022-11-05 19:50:11
【Question】:

I tried to use torch.utils.data.random_split as follows:

import torch
from torch.utils.data import DataLoader, random_split

list_dataset = [1,2,3,4,5,6,7,8,9,10]
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)

random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))

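However, when I run this, I get the error raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

I looked at the docs, and it seems I should be able to pass in fractions that sum to 1, but apparently that does not work.

I also googled this error, and the closest thing I found was this issue.

What am I doing wrong?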

【Question comments】:

    Tags: pytorch dataset


    【Solution 1】:

    You are probably using an older version of PyTorch, such as PyTorch 1.10, which does not have this functionality.
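If upgrading is not an option, one workaround that should behave the same on old and new versions is to convert the fractions to integer lengths yourself. The sketch below reuses the data and seed from the question; the names `n_train`, `n_val`, and `n_test` are made up for illustration. It also passes the list itself to `random_split` rather than a `DataLoader`, since the function is documented as taking a dataset, and wraps a subset in a `DataLoader` only after splitting.

```python
import torch
from torch.utils.data import DataLoader, random_split

list_dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Convert the fractions to absolute sizes by hand. The last split takes
# whatever is left, so the sizes always sum to len(list_dataset).
n_total = len(list_dataset)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)
n_test = n_total - n_train - n_val

generator = torch.Generator().manual_seed(123)
train_set, val_set, test_set = random_split(
    list_dataset, [n_train, n_val, n_test], generator=generator
)

# Split the dataset first, then wrap each subset in its own DataLoader.
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
```

Integer lengths are accepted by both old and new releases, so this avoids depending on the fractional-length feature at all.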

    To replicate this functionality in an older version, you can simply copy the source code from a newer version:

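    import math
    import warnings  # used below to warn about empty splits
    from typing import List  # used below in the subset_lengths annotation

    from torch import default_generator, randperm
    from torch._utils import _accumulate
    from torch.utils.data.dataset import Subset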
    
    def random_split(dataset, lengths,
                     generator=default_generator):
        r"""
        Randomly split a dataset into non-overlapping new datasets of given lengths.
    
        If a list of fractions that sum up to 1 is given,
        the lengths will be computed automatically as
        floor(frac * len(dataset)) for each fraction provided.
    
        After computing the lengths, if there are any remainders, 1 count will be
        distributed in round-robin fashion to the lengths
        until there are no remainders left.
    
        Optionally fix the generator for reproducible results, e.g.:
    
        >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
        ...   ).manual_seed(42))
    
        Args:
            dataset (Dataset): Dataset to be split
            lengths (sequence): lengths or fractions of splits to be produced
            generator (Generator): Generator used for the random permutation.
        """
        if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
            subset_lengths: List[int] = []
            for i, frac in enumerate(lengths):
                if frac < 0 or frac > 1:
                    raise ValueError(f"Fraction at index {i} is not between 0 and 1")
                n_items_in_split = int(
                    math.floor(len(dataset) * frac)  # type: ignore[arg-type]
                )
                subset_lengths.append(n_items_in_split)
            remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
            # add 1 to all the lengths in round-robin fashion until the remainder is 0
            for i in range(remainder):
                idx_to_add_at = i % len(subset_lengths)
                subset_lengths[idx_to_add_at] += 1
            lengths = subset_lengths
            for i, length in enumerate(lengths):
                if length == 0:
                    warnings.warn(f"Length of split at index {i} is 0. "
                                  f"This might result in an empty dataset.")
    
        # Cannot verify that dataset is Sized
        if sum(lengths) != len(dataset):    # type: ignore[arg-type]
            raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    
        indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
        return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
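To make the length computation in the function above easier to follow, here is a minimal pure-Python sketch of just that step: floor each fraction of the dataset size, then hand out any leftover items round-robin. The helper name `fractions_to_lengths` is hypothetical and not part of PyTorch, and the sketch leaves out the per-fraction range check and the empty-split warning.

```python
import math
from typing import List, Sequence


def fractions_to_lengths(n_items: int, fractions: Sequence[float]) -> List[int]:
    """Hypothetical helper: turn fractions summing to 1 into integer split sizes."""
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("Fractions must sum to 1")
    # Step 1: floor(frac * n_items) for every fraction.
    lengths = [math.floor(n_items * frac) for frac in fractions]
    # Step 2: hand out the leftover items one at a time, round-robin.
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(fractions_to_lengths(10, [0.5, 0.25, 0.25]))  # [6, 2, 2]
```

With 10 items, the floors are 5, 2, and 2, which leaves one item over; the round-robin step gives it to the first split.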
    

    【Comments】:
