【问题标题】:Quickselect algorithm for singly linked list C++单链表 C++ 的快速选择算法
【发布时间】:2020-07-01 17:12:18
【问题描述】:

我需要一种算法,它可以在线性时间复杂度 O(n) 和恒定空间复杂度 O(1) 中找到单链表的中位数。

编辑:单链表是 C 风格的单链表。不允许 stl(没有容器,没有函数,所有的 stl 都是被禁止的,例如没有 std::forward_list)。不允许移动任何其他容器(如数组)中的数字。 空间复杂度为 O(logn) 是可以接受的,因为对于我的列表来说,这实际上甚至低于 100。我也不允许使用像 nth_element 这样的 STL 函数

基本上我有类似 3 * 10^6 元素的链表,我需要在 3 秒内获得中位数,所以我不能使用排序算法对列表进行排序(这将是 O(nlogn) 并且将大概需要 10-14 秒)。

我在网上做了一些搜索,发现可以使用 quickselect 在 O(n) 和 O(1) 空间中找到 std::vector 的中值(最坏的情况是O(n^2),但很少见),例如:https://www.geeksforgeeks.org/quickselect-a-simple-iterative-implementation/

但我找不到任何可以为链表执行此操作的算法。问题是我可以使用数组索引来随机访问向量如果我想修改该算法,复杂性会更大,因为。例如,当我将pivotindex更改为左侧时,我实际上需要遍历列表以获取该新元素并走得更远(这将使我的列表中的k至少为O(kn),甚至接近O(n ^ 2)...)。

编辑 2:

我知道我有太多变量,但我一直在测试不同的东西,我仍在处理我的代码...... 我当前的代码:

#include <bits/stdc++.h>

using namespace std;

template <class T> class Node {
    public:
    T data;
    Node<T> *next;
};

template <class T> class List {
    public:
    Node<T> *first;
};

template <class T> T getMedianValue(List<T> & l) {
    Node<T> *crt,*pivot,*incpivot;
    int left, right, lung, idx, lungrel,lungrel2, left2, right2, aux, offset;
    pivot = l.first;
    crt = pivot->next;
    lung = 1;
//lung is the lenght of the linked list (yeah it's lenght in romanian...)
//lungrel and lungrel2 are the relative lenghts of the part of 
//the list I am processing, e.g: 2 3 4 in a list with 1 2 3 4 5
    right = left = 0;
    while (crt != NULL) { 
        if(crt->data < pivot->data){
            aux = pivot->data;
            pivot->data = crt->data;
            crt->data = pivot->next->data;
            pivot->next->data = aux;
            pivot = pivot->next;
            left++;
        }
        else right++;
       // cout<<crt->data<<endl;
        crt = crt->next; 
        lung++; 
    }
    if(right > left) offset = left;
//  cout<<endl;
//  cout<<pivot->data<<" "<<left<<" "<<right<<endl;
//  printList(l);
//  cout<<endl;
    lungrel = lung;
    incpivot = l.first;
   // offset = 0;
    while(left != right){
        //cout<<"parcurgere"<<endl;
        if(left > right){
            //cout<<endl;
            //printList(l);
            //cout<<endl;
            //cout<<"testleft "<<incpivot->data<<" "<<left<<" "<<right<<endl;
            crt = incpivot->next;
            pivot = incpivot;
            idx = offset;left2 = right2 = lungrel = 0;
            //cout<<idx<<endl;
            while(idx < left && crt!=NULL){
                 if(pivot->data > crt->data){
                   //  cout<<"1crt "<<crt->data<<endl;
                     aux = pivot->data;
                     pivot->data = crt->data;
                     crt->data = pivot->next->data;
                     pivot->next->data = aux;
                     pivot = pivot->next;
                     left2++;lungrel++;
                  }
                  else {
                      right2++;lungrel++;
                      //cout<<crt->data<<" "<<right2<<endl;
                  }
                  //cout<<crt->data<<endl;
                  crt = crt->next;
                  idx++;
             }
             left = left2 + offset;
             right = lung - left - 1;
             if(right > left) offset = left;
             //if(pivot->data == 18) return 18;
             //cout<<endl;
             //cout<<"l "<<pivot->data<<" "<<left<<" "<<right<<" "<<right2<<endl;
           //  printList(l);
        }
        else if(left < right && pivot->next!=NULL){
            idx = left;left2 = right2 = 0;
            incpivot = pivot->next;offset++;left++;
            //cout<<endl;
            //printList(l);
            //cout<<endl;
            //cout<<"testright "<<incpivot->data<<" "<<left<<" "<<right<<endl;
            pivot = pivot->next;
            crt = pivot->next;
            lungrel2 = lungrel;
            lungrel = 0;
           // cout<<"p right"<<pivot->data<<" "<<left<<" "<<right<<endl;
            while((idx < lungrel2 + offset - 1) && crt!=NULL){
                 if(crt->data < pivot->data){
                //     cout<<"crt "<<crt->data<<endl;
                     aux = pivot->data;
                     pivot->data = crt->data;
                     crt->data = (pivot->next)->data;
                     (pivot->next)->data = aux;
                     pivot = pivot->next;
                 //    cout<<"crt2 "<<crt->data<<endl;
                     left2++;lungrel++;
                  }
                  else right2++;lungrel++;
                  //cout<<crt->data<<endl;
                  crt = crt->next;
                  idx++;
             }
             left = left2 + left;
             right = lung - left - 1;
                 if(right > left) offset = left;
            // cout<<"r "<<pivot->data<<" "<<left<<" "<<right<<endl;
           //  printList(l);
        }
        else{
            //cout<<cmx<<endl;
            return pivot->data;
        }
    }
    //cout<<cmx<<endl;
    return pivot->data;
}
template <class T> void printList(List<T> const & l) {
    Node<T> *tmp;
    if(l.first != NULL){
        tmp = l.first;
        while(tmp != NULL){
            cout<<tmp->data<<" ";
            tmp = tmp->next;
        }
    }
}
template <class T> void push_front(List<T> & l, int x)
{
    Node<T>* tmp = new Node<T>;

    tmp->data = x;

    tmp->next = l.first;
    l.first = tmp;
}

int main(){
    List<int> l;
    int n = 0;
    push_front(l, 19);
    push_front(l, 12);
    push_front(l, 11);
    push_front(l, 101);
    push_front(l, 91);
    push_front(l, 21);
    push_front(l, 9);
    push_front(l, 6);
    push_front(l, 25);
    push_front(l, 4);
    push_front(l, 18);
    push_front(l, 2);
    push_front(l, 8);
    push_front(l, 10);
    push_front(l, 200);
    push_front(l, 225);
    push_front(l, 170);
    printList(l);
    n=getMedianValue(l);
    cout<<endl;
    cout<<n;

    return 0;
}

您对如何使快速选择适应单独列出的链接或其他可以解决我的问题的算法有任何建议吗?

【问题讨论】:

  • 对链表进行合并排序(你必须排序才能找到中位数)。您将不得不使用“自下而上”(迭代)合并排序来避免使用超过 100K 节点的递归来炸毁堆栈。您的排序将在 3 秒内完成。时间(一种 3M 节点应该需要 ~0.3 秒)。您可以搜索并且应该找到一些处理链表的示例。
  • 我已经用我当前的代码更新了帖子。该单链表是任务的一部分,我无法将该列表升级为双向链表,但我可以对列表进行更改,例如移动其中的元素。由于所需的时间复杂度(没有比较排序比 O(nlogn) 更快),我无法进行任何完整排序,我只能依赖“部分排序”,例如我试图在我的代码中实现的 quickselect。我已经测试了我的代码,它适用于具有相当复杂性的小型列表。就像一个包含 18 个元素(未排序)的列表的约 40 次操作。
  • @alexcojocaru:解决方案是否还必须能够有效地处理已排序的数据?或者在最坏的情况下(即数据已经排序时),解决方案是否允许具有 O(n^2) 的时间复杂度?我们可以假设数据是未排序的吗?
  • O(n^2) 的最差时间复杂度被排序数据接受(因为排序数据的情况在大多数情况下不会发生,因为链表中的数据大多是随机的),但是我需要 O(kn) 的平均时间复杂度,大多数情况下 k
  • @alexcojocaru:我们可以假设列表中没有重复的元素,即所有值都是唯一的吗?我问是因为如果数字都相同,取决​​于算法,这可能会导致最坏情况的时间复杂度。

标签: c++ list median quickselect nth-element


【解决方案1】:

在您的问题中,您提到您无法选择不在列表开头的枢轴,因为这需要遍历列表。如果你做对了,你只需要遍历整个列表两次:

  1. 一次用于查找列表的中间和末尾以选择一个好的枢轴(例如,使用"median-of-three" 规则)
  2. 实际排序一次

如果您不太关心选择一个好的枢轴并且您很高兴只需选择列表的第一个元素作为枢轴(这会导致最坏的情况 O(n^2) @987654322 @如果数据已经排序)。

如果您通过维护指向末尾的指针第一次遍历它时记住了列表的末尾,那么您永远不必再次遍历它以找到末尾。此外,如果您使用标准的 Lomuto partition scheme(我没有使用它,原因如下),那么您还必须维护两个指向列表的指针,它们代表标准 Lomuto 分区的 ij 索引方案。通过使用这些指针,永远不必遍历列表来访问单个元素。

此外,如果您维护一个指向每个分区的中间和结尾的指针,那么,当您以后必须对这些分区中的一个进行排序时,您将不必再次遍历该分区来找到中间和结尾。

我现在已经为链表创建了自己的QuickSelect 算法实现,我在下面发布了该算法。

既然你说链表是单链表并且不能升级为双链表,我不能使用Hoare partition scheme,因为向后迭代单链表非常昂贵。因此,我改用通常效率较低的Lomuto partition scheme

使用 Lomuto 分区方案时,通常选择第一个元素或最后一个元素作为枢轴。但是,选择其中任何一个都有一个缺点,即排序的数据将导致算法具有 O(n^2) 的最坏情况时间复杂度。这可以通过根据"median-of-three" rule选择一个pivot来防止,即从第一个元素、中间元素和最后一个元素的中间值中选择一个pivot。因此,在我的实现中,我使用的是“三中位数”规则。

此外,Lomuto 分区方案通常会创建两个分区,一个用于小于基准的值,另一个用于大于或等于基准的值。但是,如果所有值都相同,这将导致 O(n^2) 的最坏情况时间复杂度。因此,在我的实现中,我创建了三个分区,一个用于小于枢轴的值,一个用于大于枢轴的值,一个用于等于枢轴的值。

虽然这些措施并不能完全消除 O(n^2) 的最坏情况时间复杂度的可能性,但它们至少使这种可能性极小(除非输入是由恶意攻击者提供的)。为了保证 O(n) 的时间复杂度,必须使用更复杂的枢轴选择算法,例如median of medians

我遇到的一个重要问题是,对于偶数个元素,median 被定义为两个“中间”或“中间”元素的arithmetic mean。出于这个原因,我不能简单地编写类似于std::nth_element 的函数,因为例如,如果元素总数为 14,那么我将寻找第 7 和第 8 大的元素。这意味着我必须调用这样的函数两次,这将是低效的。因此,我编写了一个可以同时搜索两个“中值”元素的函数。虽然这使代码更加复杂,但与不必调用相同函数两次的优势相比,由于额外的代码复杂性而导致的性能损失应该是最小的。

请注意,尽管我的实现可以在 C++ 编译器上完美编译,但我不会将其称为教科书 C++ 代码,因为问题表明我不允许使用 C++ 标准模板库中的任何内容。因此,我的代码是 C 代码和 C++ 代码的混合体。

在下面的代码中,我只使用标准模板库(特别是函数std::nth_element)来测试我的算法并验证结果。在我的实际算法中,我没有使用任何这些函数。

#include <iostream>
#include <iomanip>
#include <cassert>

// The following two headers are only required for testing the algorithm and verifying
// the correctness of its results. They are not used in the algorithm itself.
#include <random>
#include <algorithm>

// The following setting can be changed to print extra debugging information
// possible settings:
// 0: no extra debugging information
// 1: print the state and length of all partitions in every loop iteraton
// 2: additionally print the contents of all partitions (if they are not too big)
#define PRINT_DEBUG_LEVEL 0

template <typename T>
struct Node
{
    T data;
    Node<T> *next;
};

// NOTE:
// The return type is not necessarily the same as the data type. The reason for this is
// that, for example, the data type "int" requires a "double" as a return type, so that 
// the arithmetic mean of "3" and "6" returns "4.5".
// This function may require template specializations to handle overflow or wrapping.
template<typename T, typename U>
U arithmetic_mean( const T &first, const T &second )
{
    return ( static_cast<U>(first) + static_cast<U>(second) ) / 2;
}

//the main loop of the function find_median can be in one of the following three states
enum LoopState
{
    //we are looking for one median value
    LOOPSTATE_LOOKINGFORONE,

    //we are looking for two median values, and the returned median
    //will be the arithmetic mean of the two
    LOOPSTATE_LOOKINGFORTWO,

    //one of the median values has been found, but we are still searching for
    //the second one
    LOOPSTATE_FOUNDONE
};

template <
    typename T, //type of the data
    typename U  //type of the return value
>
U find_median( Node<T> *list )
{
    //This variable points to the pointer to the first element of the current partition.
    //During the partition phase, the linked list will be broken and reassembled afterwards, so
    //the pointer this pointer points to will be nullptr until it is reassembled.
    Node<T> **pp_start = &list;

    //This pointer represents nothing more than the cached value of *pp_start and it is
    //not always valid
    Node<T> *p_start = *pp_start;

    //These pointers are maintained for accessing the middle of the list for selecting a pivot
    //using the "median-of-three" rule.
    Node<T> *p_middle;
    Node<T> *p_end;

    //result is not defined if list is empty
    assert( p_start != nullptr );

    //in the main loop, this variable always holds the number of elements in the current partition
    int num_total = 1;

    // First, we must traverse the entire linked list in order to determine the number of elements,
    // in order to calculate k1 and k2. If it is odd, then the median is defined as the k'th smallest
    // element where k = n / 2. If the number of elements is even, then the median is defined as the
    // arithmetic mean of the k'th element and the (k+1)'th element.
    // We also set a pointer to the nodes in the middle and at the end, which will be required later
    // for selecting a pivot according to the "median-of-three" rule.
    p_middle = p_start;
    for ( p_end = p_start; p_end->next != nullptr; p_end = p_end->next )
    {
        num_total++;
        if ( num_total % 2 == 0 ) p_middle = p_middle->next;
    }   

    // find out whether we are looking for only one or two median values
    enum LoopState loop_state = num_total % 2 == 0 ? LOOPSTATE_LOOKINGFORTWO : LOOPSTATE_LOOKINGFORONE;

    //set k to the index of the middle element, or if there are two middle elements, to the left one
    int k = ( num_total - 1 ) / 2;

    // If we are looking for two median values, but we have only found one, then this variable will
    // hold the value of the one we found. Whether we have found one can be determined by the state of
    // the variable loop_state.
    T val_found;

    for (;;)
    {
        //make p_start cache the value of *pp_start again, because a previous iteration of the loop
        //may have changed the value of pp_start
        p_start = *pp_start;

        assert( p_start   != nullptr );
        assert( p_middle  != nullptr );
        assert( p_end     != nullptr );
        assert( num_total != 0 );

        if ( num_total == 1 )
        {
            switch ( loop_state )
            {
            case LOOPSTATE_LOOKINGFORONE:
                return p_start->data;
            case LOOPSTATE_FOUNDONE:
                return arithmetic_mean<T,U>( val_found, p_start->data );
            default:
                assert( false ); //this should be unreachable
            }
        }

        //select the pivot according to the "median-of-three" rule
        T pivot;
        if ( p_start->data < p_middle->data )
        {
            if ( p_middle->data < p_end->data )
                pivot = p_middle->data;
            else if ( p_start->data < p_end->data )
                pivot = p_end->data;
            else
                pivot = p_start->data;
        }
        else
        {
            if ( p_start->data < p_end->data )
                pivot = p_start->data;
            else if ( p_middle->data < p_end->data )
                pivot = p_end->data;
            else
                pivot = p_middle->data;
        }

#if PRINT_DEBUG_LEVEL >= 1
        //this line is conditionally compiled for extra debugging information
        std::cout << "\nmedian of three: " << (*pp_start)->data << " " << p_middle->data << " " << p_end->data << " ->" << pivot << std::endl;
#endif

        // We will be dividing the current partition into 3 new partitions (less-than,
        // equal-to and greater-than) each represented as a linked list. Each list
        // requires a pointer to the start of the list and a pointer to the pointer at
        // the end of the list to write the address of new elements to. Also, when
        // traversing the lists, we need to keep a pointer to the middle of the list,
        // as this information will be required for selecting a new pivot in the next
        // iteration of the loop. The latter is not required for the equal-to partition,
        // as it would never be used.
        Node<T> *p_less    = nullptr, **pp_less_end    = &p_less,    **pp_less_middle    = &p_less;
        Node<T> *p_equal   = nullptr, **pp_equal_end   = &p_equal;
        Node<T> *p_greater = nullptr, **pp_greater_end = &p_greater, **pp_greater_middle = &p_greater;

        // These pointers are only used as a cache to the location of the end node.
        // Despite their similar name, their function is quite different to pp_less_end
        // and pp_greater_end.
        Node<T> *p_less_end    = nullptr;
        Node<T> *p_greater_end = nullptr;

        // counter for the number of elements in each partition
        int num_less = 0;
        int num_equal = 0;
        int num_greater = 0;

        // NOTE:
        // The following loop will temporarily split the linked list. It will be merged later.

        Node<T> *p_next_node = p_start;

        //the following line isn't necessary; it is only used to clarify that the pointers no
        //longer point to anything meaningful
        *pp_start = p_start = nullptr;

        for ( int i = 0; i < num_total; i++ )
        {
            assert( p_next_node != nullptr );

            Node<T> *p_current_node = p_next_node;
            p_next_node = p_next_node->next;

            if ( p_current_node->data < pivot )
            {
                //link node to pp_less
                assert( *pp_less_end == nullptr );
                *pp_less_end = p_less_end = p_current_node;
                pp_less_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_less++;
                if ( num_less % 2 == 0 )
                {
                    pp_less_middle = &(*pp_less_middle)->next;
                }
            }
            else if ( p_current_node->data == pivot )
            {
                //link node to pp_equal
                assert( *pp_equal_end == nullptr );
                *pp_equal_end = p_current_node;
                pp_equal_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_equal++;
            }
            else
            {
                //link node to pp_greater
                assert( *pp_greater_end == nullptr );
                *pp_greater_end = p_greater_end = p_current_node;
                pp_greater_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_greater++;
                if ( num_greater % 2 == 0 )
                {
                    pp_greater_middle = &(*pp_greater_middle)->next;
                }
            }
        }

        assert( num_total == num_less + num_equal + num_greater );
        assert( num_equal >= 1 );

#if PRINT_DEBUG_LEVEL >= 1
        //this section is conditionally compiled for extra debugging information
        {
            std::cout << std::setfill( '0' );
            switch ( loop_state )
            {
            case LOOPSTATE_LOOKINGFORONE:
                std::cout << "LOOPSTATE_LOOKINGFORONE k = " << k << "\n";
                break;
            case LOOPSTATE_LOOKINGFORTWO:
                std::cout << "LOOPSTATE_LOOKINGFORTWO k = " << k << "\n";
                break;
            case LOOPSTATE_FOUNDONE:
                std::cout << "LOOPSTATE_FOUNDONE k = " << k << " val_found = " << val_found << "\n";
            }
            std::cout << "partition lengths: ";
            std::cout <<
                std::setw( 2 ) << num_less    << " " <<
                std::setw( 2 ) << num_equal   << " " <<
                std::setw( 2 ) << num_greater << " " <<
                std::setw( 2 ) << num_total   << "\n";
#if PRINT_DEBUG_LEVEL >= 2
            Node<T> *p;
            std::cout << "less: ";
            if ( num_less > 10 )
                std::cout << "too many to print";
            else
                for ( p = p_less; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\nequal: ";
            if ( num_equal > 10 )
                std::cout << "too many to print";
            else
                for ( p = p_equal; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\ngreater: ";
            if ( num_greater > 10 )
                std::cout << "too many to print";
            else
                for ( p = p_greater; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\n\n" << std::flush;
#endif
            std::cout << std::flush;
        }
#endif

        //insert less-than partition into list
        assert( *pp_start == nullptr );
        *pp_start = p_less;

        //insert equal-to partition into list
        assert( *pp_less_end == nullptr );
        *pp_less_end = p_equal;

        //insert greater-than partition into list
        assert( *pp_equal_end == nullptr );
        *pp_equal_end = p_greater;

        //link list to previously cut off part
        assert( *pp_greater_end == nullptr );
        *pp_greater_end = p_next_node;

        //if less-than partition is large enough to hold both possible median values
        if ( k + 2 <= num_less )
        {
            //set the next iteration of the loop to process the less-than partition
            //pp_start is already set to the desired value
            p_middle = *pp_less_middle;
            p_end = p_less_end;
            num_total = num_less;
        }

        //else if less-than partition holds one of both possible median values
        else if ( k + 1 == num_less )
        {
            if ( loop_state == LOOPSTATE_LOOKINGFORTWO )
            {
                //the equal_to partition never needs sorting, because all members are already equal
                val_found = p_equal->data;
                loop_state = LOOPSTATE_FOUNDONE;
            }
            //set the next iteration of the loop to process the less-than partition
            //pp_start is already set to the desired value
            p_middle = *pp_less_middle;
            p_end = p_less_end;
            num_total = num_less;
        }

        //else if equal-to partition holds both possible median values
        else if ( k + 2 <= num_less + num_equal )
        {
            //the equal_to partition never needs sorting, because all members are already equal
            if ( loop_state == LOOPSTATE_FOUNDONE )
                return arithmetic_mean<T,U>( val_found, p_equal->data );
            return p_equal->data;
        }

        //else if equal-to partition holds one of both possible median values
        else if ( k + 1 == num_less + num_equal )
        {
            switch ( loop_state )
            {
            case LOOPSTATE_LOOKINGFORONE:
                return p_equal->data;
            case LOOPSTATE_LOOKINGFORTWO:
                val_found = p_equal->data;
                loop_state = LOOPSTATE_FOUNDONE;
                k = 0;
                //set the next iteration of the loop to process the greater-than partition
                pp_start = pp_equal_end;
                p_middle = *pp_greater_middle;
                p_end = p_greater_end;
                num_total = num_greater;
                break;
            case LOOPSTATE_FOUNDONE:
                return arithmetic_mean<T,U>( val_found, p_equal->data );
            }
        }

        //else both possible median values must be in the greater-than partition
        else
        {
            k = k - num_less - num_equal;

            //set the next iteration of the loop to process the greater-than partition
            pp_start = pp_equal_end;
            p_middle = *pp_greater_middle;
            p_end = p_greater_end;
            num_total = num_greater;
        }
    }
}


// NOTE:
// The following code is not part of the algorithm, but is only intended to test the algorithm

// This simple class is designed to contain a singly-linked list
template <typename T>
class List
{
public:
    List() : first( nullptr ) {}

    // the following is required to abide by the rule of three/five/zero
    // see: https://en.cppreference.com/w/cpp/language/rule_of_three
    List( const List<T> & ) = delete;
    List( const List<T> && ) = delete;
    List<T>& operator=( List<T> & ) = delete;
    List<T>& operator=( List<T> && ) = delete;

    ~List()
    {
        Node<T> *p = first;

        while ( p != nullptr )
        {
            Node<T> *temp = p;
            p = p->next;
            delete temp;
        }
    }

    void push_front( int data )
    {
        Node<T> *temp = new Node<T>;

        temp->data = data;

        temp->next = first;
        first = temp;
    }

    //member variables
    Node<T> *first;
};

int main()
{
    //generated random numbers will be between 0 and 2 billion (fits in 32-bit signed int)
    constexpr int min_val = 0;
    constexpr int max_val = 2*1000*1000*1000;

    //will allocate array for 1 million ints and fill with random numbers
    constexpr int num_values = 1*1000*1000;

    //this class contains the singly-linked list and is empty for now
    List<int> l;
    double result;

    //These variables are used for random number generation
    std::random_device rd;
    std::mt19937 gen( rd() );
    std::uniform_int_distribution<> dis( min_val, max_val );

    try
    {
        //fill array with random data
        std::cout << "Filling array with random data..." << std::flush;
        auto unsorted_data = std::make_unique<int[]>( num_values );
        for ( int i = 0; i < num_values; i++ ) unsorted_data[i] = dis( gen );

        //fill the singly-linked list
        std::cout << "done\nFilling linked list..." << std::flush;
        for ( int i = 0; i < num_values; i++ ) l.push_front( unsorted_data[i] );

        std::cout << "done\nCalculating median using STL function..." << std::flush;

        //calculate the median using the functions provided by the C++ standard template library.
        //Note: this is only done to compare the results with the algorithm provided in this file
        if ( num_values % 2 == 0 )
        {
            int median1, median2;

            std::nth_element( &unsorted_data[0], &unsorted_data[(num_values - 1) / 2], &unsorted_data[num_values] );
            median1 = unsorted_data[(num_values - 1) / 2];
            std::nth_element( &unsorted_data[0], &unsorted_data[(num_values - 0) / 2], &unsorted_data[num_values] );
            median2 = unsorted_data[(num_values - 0) / 2];

            result = arithmetic_mean<int,double>( median1, median2 );
        }
        else
        {
            int median;

            std::nth_element( &unsorted_data[0], &unsorted_data[(num_values - 0) / 2], &unsorted_data[num_values] );
            median = unsorted_data[(num_values - 0) / 2];

            result = static_cast<int>(median);
        }

        std::cout << "done\nMedian according to STL function: " << std::setprecision( 12 ) << result << std::endl;

        // NOTE: Since the STL functions only sorted the array, but not the linked list, the 
        //       order of the linked list is still random and not pre-sorted.

        //calculate the median using the algorithm provided in this file
        std::cout << "Starting algorithm" << std::endl;
        result = find_median<int,double>( l.first );
        std::cout << "The calculated median is: " << std::setprecision( 12 ) << result << std::endl;

        std::cout << "Cleaning up\n\n" << std::flush;
    }
    catch ( std::bad_alloc )
    {
        std::cerr << "Error: Unable to allocate sufficient memory!" << std::endl;
        return -1;
    }

    return 0;
}

我已经用一百万个随机生成的元素成功地测试了我的代码,它几乎立即找到了正确的中位数。

【讨论】:

  • 请注意,我的原始答案中的代码有一个错误,我现在已经修复了。在编辑历史的第 17 版中,我添加了两行代码来修复错误。
【解决方案2】:

所以你可以做的是使用迭代器来保持位置。我已经编写了上面的算法来使用 std::forward_list。我知道这并不完美,但很快就写出来了,希望对您有所帮助。

    int partition(int leftPos, int rightPos, std::forward_list<int>::iterator& currIter, 
    std::forward_list<int>::iterator lowIter, std::forward_list<int>::iterator highIter) {
        auto iter = lowIter;
        int i = leftPos - 1;
        for(int j = leftPos; j < rightPos - 1; j++) {
           if(*iter <= *highIter) {
               ++currIter;
               ++i;
               std::iter_swap(currIter, iter);
           }
           iter++;
        }
        std::forward_list<int>::iterator newIter = currIter;
        std::iter_swap(++newIter, highIter);
        return i + 1;
    }

   std::forward_list<int>::iterator kthSmallest(std::forward_list<int>& list, 
   std::forward_list<int>::iterator left, std::forward_list<int>::iterator right, int size, int k) {
       int leftPos {0};
       int rightPos {size};
       int pivotPos {0};

       std::forward_list<int>::iterator resetIter = left;
       std::forward_list<int>::iterator currIter = left;
       ++left;
       while(leftPos <= rightPos) {
           pivotPos = partition(leftPos, rightPos, currIter, left, right);

           if(pivotPos == (k-1)) {
               return currIter;
           } else if(pivotPos > (k-1)) {
               right = currIter;
               rightPos = pivotPos - 1;
           } else {
               left = currIter;
               ++left;
               resetIter = left;
               ++left;
               leftPos = pivotPos + 1;
           }

           currIter = resetIter;
       }

       return list.end();
  }

当调用第 k 个迭代器时,左迭代器应该比你打算开始的地方少一个。这使我们在partition() 中落后low 一个位置。这是一个执行它的例子:

int main() {
    std::forward_list<int> list {10, 12, 12, 13, 4, 5, 8, 11, 6, 26, 15, 21};
    auto startIter = list.before_begin();
    int k = 6;
    int size = getSize(list);

    auto kthIter = kthSmallest(list, startIter, getEnd(list), size - 1, k);
    std::cout << k << "th smallest: " << *kthIter << std::endl;

    return 0;
}

第六小:10

【讨论】:

  • 当元素个数为偶数时,那么median通常定义为两个“中间”值的arithmetic mean
  • 我只是根据他使用std::forward_list 发布的链接为他提供了一个快速选择版本,您可以调用 kthSmallest 来获取两个“中间”值,然后使用它获得算术平均值.
  • 感谢您的回答。不幸的是,我的链表是一个“自定义”链表,它没有使用 stl 中的任何东西。例如。 push_back() 函数已定义,我没有使用 stl 等。也许我可以尝试调整发送到我的列表的代码...我实际上设法为链接列表设置了一种快速选择。问题是我需要更好地选择枢轴,因为当我设置它时,它会选择第一个元素作为枢轴......
  • 没问题。如果您可以根据列表的要求更新您的问题,我可以帮助改进此解决方案。如果您的列表使用 LegacyForwardIterator 并且您可以访问 begin() 这可以工作。只需将我的 forward_list 替换为您的自定义位置即可。绝对是一项有趣的任务。
  • @BrandonManning:我刚刚注意到按照您的建议两次调用您的函数kthSmallest 会将时间复杂度更改为 O(n^2)。这是因为您只需选择最后一个元素作为枢轴,而不是使用“三中位数规则”(我在实现中使用)。对于已排序的数据,这会导致 O(n^2) 的最坏情况时间复杂度。由于数据在第一次函数调用中被排序,第二次调用该函数时使用几乎相同的k 意味着该函数将在第二次函数调用中处理已经排序的数据。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-06-07
  • 1970-01-01
  • 1970-01-01
  • 2018-05-22
相关资源
最近更新 更多