【问题标题】:Efficient tree implementation in MATLABMATLAB 中的高效树实现
【发布时间】:2016-02-15 05:34:32
【问题描述】:

MATLAB 中的树类

我正在 MATLAB 中实现一个树形数据结构。向树中添加新的子节点、分配和更新与节点相关的数据值是我期望执行的典型操作。每个节点都有与之关联的相同类型的data。我不需要删除节点。到目前为止,我已经决定从 handle 类继承的类实现能够将节点的引用传递给将修改树的函数。

编辑:12 月 2 日

首先,感谢到目前为止 cmets 和答案中的所有建议。他们已经帮助我提高了我的树类。

有人建议尝试 R2015b 中引入的digraph。我还没有对此进行探索,但是看到它不像从handle 继承的类那样作为参考参数工作,我有点怀疑它在我的应用程序中的工作方式。在这一点上,我还不清楚使用自定义data 来处理节点和边有多容易。

编辑:(12 月 3 日)有关主应用程序的更多信息:MCTS

最初,我认为主应用程序的细节只会引起人们的兴趣,但自从阅读了 @FirefoxMetzger 的 cmets 和 answer 后,我意识到它具有重要意义。

我正在实现一种Monte Carlo tree search 算法。以迭代方式探索和扩展搜索树。维基百科提供了一个很好的过程图形概述:

在我的应用程序中,我执行了大量的搜索迭代。在每次搜索迭代中,我从根开始遍历当前树直到叶节点,然后通过添加新节点来扩展树,然后重复。由于该方法基于随机抽样,因此在每次迭代开始时,我不知道每次迭代将在哪个叶节点结束。相反,这是由树中当前节点的data 和随机样本的结果共同确定的。我在单次迭代期间访问的任何节点都会更新其data

示例:我在节点n 有几个孩子。我需要访问每个孩子的数据并随机抽取一个样本,以确定我在搜索中移动到下一个孩子。重复此过程,直到到达叶节点。实际上,我通过在根上调用 search 函数来决定接下来要扩展哪个子节点,在该节点上递归调用 search 等等,最后在到达叶节点时返回一个值。从递归函数返回时使用此值来更新搜索迭代期间访问的节点的data

这棵树可能非常不平衡,以至于一些分支是非常长的节点链,而其他分支在根级别之后很快终止并且不会进一步扩展。

当前实现

以下是我当前的实现示例,其中包含一些用于添加节点、查询树中节点的深度或数量等的成员函数示例。

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree as a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modifed tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of ID of the children of the given node ID.
            % The list is returned as a line vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + those whose parent is not root.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH return depth of tree under ID. If ID is not given, use
            % root.
            if nargin == 1
                ID = 0;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end

但是,有时性能很慢,树上的操作占用了我大部分的计算时间。有哪些具体的方法可以提高实施效率?如果有性能提升,甚至可以将实现更改为 handle 继承类型以外的其他东西。

使用当前实现分析结果

因为向树中添加新节点是最典型的操作(以及更新节点的data),所以我在上面做了一些profiling。 我使用Nd=6, Ns=10 在以下基准代码上运行分析器。

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end

分析器的结果:

R2015b 性能


已发现 R2015b improves the performance of MATLAB's OOP features。我重新进行了上述基准测试,确实观察到了性能的提升:

所以这已经是个好消息了,尽管我们当然可以接受进一步的改进;)

以不同方式保留内存

cmets中也有人建议使用

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

保留更多内存,而不是使用repmat 的当前方法。这略微提高了性能。我应该注意我的基准代码是针对虚拟数据的,因为实际的data 更复杂,这可能会有所帮助。谢谢!探查器结果如下:

关于进一步提高性能的问题

  1. 也许有另一种更有效的方式来维护树的内存?遗憾的是,我通常不会提前知道树中有多少个节点。
  2. 添加新节点和修改现有节点的data 是我在树上做的最典型的操作。截至目前,它们实际上占用了我主要应用程序的大部分处理时间。欢迎对这些功能进行任何改进。

最后一点,理想情况下,我希望将实现保持为纯 MATLAB。但是,可以接受 MEX 等选项或使用一些集成的 Java 功能。

【问题讨论】:

  • 运行profiler 可以在性能方面在您的代码中阐明很多。运行一次,看看代码哪里特别慢,它会给你一个从哪里开始改进的指针。
  • Matlab OOP adds a significant overhead unless you use Matlab 2015b or newer 这可能会导致问题。不使用handle 可能无济于事。
  • @Adriaan 感谢您的建议。我添加了一些分析器数据。
  • 另外,根据您的数据是什么,使用repmat 分配节点数据可能会增加很多开销。为什么不用obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)]; 初始化?
  • 当您修改数据时,您是如何访问节点的?更具体地说,为什么要保存节点的父节点而不是其子节点?从外观上看,使用单个查找表或结构来存储数据可能会更快。

标签: performance matlab data-structures tree


【解决方案1】:

我知道这听起来可能很愚蠢...但是保留空闲节点的数量而不是节点的总数怎么样?这需要与一个常量(为零)进行比较,这是单一属性访问。

另一个巫术改进是将.maxSize 移动到.num_nodes 附近,并将它们放在.Node 单元之前之前。像这样,由于.Node 属性的增长,它们在内存中的位置不会相对于对象的开头发生变化(这里的巫术是我在猜测 MATLAB 中对象的内部实现)。

稍后编辑当我将.Node 移到属性列表末尾进行分析时,大部分执行时间都被扩展.Node 属性所消耗,正如预期的那样(5.45 秒,而您提到的比较需要 1.25 秒)。

【讨论】:

  • 有趣!与常数进行比较似乎是个好主意,我确实观察到了轻微的改进。虽然不确定.Node 属性 - 我没有观察到切换属性位置的改进。我的理解是,元胞数组元素在内存中不一定占据连续的块,所以对性能的影响不好预测。
  • @mikkola 问题是,单元格数组的元素不在连续的内存区域中,但它们的——我们称之为——地址必须是。从概念上讲,单元数组就像一个指针数组,并且该数组本身在一个连续的内存区域中增长。顺便问一下,你每次运行基准测试时都会清除你的课程吗?这通常会降低新实例化对象的性能(JIT 的东西,加上元类需要再次重新创建)
  • 感谢您对地址指针的澄清!我每次都重复添加clear allclear classes 的基准测试。尽管我同意您的建议是有道理的,但我仍然只看到性能没有显着改善。
  • @mikkola 很抱歉我的建议没有带来显着的效果。同样,当实现不透明时,所有人所能做的就是尝试几种可能有意义的理论(这就是我所说的巫毒编程)。最终,如果一个人不断冲击系统,它可能会找到它的最佳位置,或者如果不等到下一个可以提高性能的版本。
【解决方案2】:

您可以尝试分配与实际填充的元素数量成正比的元素数量:这是 c++ 中 std::vector 的标准实现

obj.Node = [obj.Node; data; cell(q * obj.num_nodes,1)];

我记不清了,但在 MSCC 中,q 是 1,而 GCC 是 0.75。


这是一个使用 Java 的解决方案。我不太喜欢它,但它确实发挥了作用。我实现了您从维基百科中提取的示例。

import javax.swing.tree.DefaultMutableTreeNode

% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)

% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
    % Java transposes the matrices, remember to transpose back when you are reading
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)

% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
    itnode=node(it);
    val = itnode.getUserObject()'
    newitemdata = val + newdata
    itnode.setUserObject(newitemdata)
end

% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

【讨论】:

  • 这将是我的建议。每次向量用完空间时,它都会保留 2x。计算机喜欢 2 的幂 :)
  • 感谢您的回答。澄清:searching = [0 1 1]; 表示我们想通过以下策略找到一个节点:root 的第一个孩子,该节点的第二个孩子,最后以该节点的第二个孩子结束 - 对吗?快速测试似乎表明 UserObject 可能是任何 Matlab 支持的类型,希望这也是正确的。
  • searching 数组的解释是正确的(数组在 Java 中是从 0 开始的)。您可以存储在 Java 数组中的对象(如句柄)是有限制的,您必须尝试使用​​您的有效对象。
【解决方案3】:

TL:DR您深度复制存储在每次插入中的整个数据,将parentNode 单元格初始化为比您预期需要的更大。

您的数据确实具有树结构,但是您没有在实现中使用它。相反,实现的代码是查找表(实际上是 2 个表)的计算量大版本,用于存储树的数据和关系数据。

我这么说的原因如下:

  • 要插入你调用stree.addnote(parent, data),它将所有数据存储在树对象stree的字段Node = {}Parent = []
  • 您似乎事先知道要访问树中的哪个元素,因为没有给出搜索代码(如果您使用 stree.getchild(ID),我有一些坏消息)
  • 一旦你处理了一个节点,你就可以使用find() 追踪它,这是一个列表搜索

这绝不意味着实现对于数据来说是笨拙的,它甚至可能是最好的,这取决于你在做什么。但是,它确实解释了您的内存分配问题并提供了有关如何解决这些问题的提示。


将数据保存为查找表

存储数据的方法之一是保留底层查找表。我只会这样做,如果您知道要修改的第一个元素的ID而不搜索它。这种情况下,您可以分两步使您的结构更高效。

首先初始化你的数组更大然后你期望你需要存储数据。如果超过了查找表的容量,则初始化一个新的,其中 X 字段更大,并对旧数据进行深拷贝。如果您需要一次或两次扩展容量(在所有插入期间),这可能不是问题,但在您的情况下,为永远插入制作一个深层副本!

其次,我将更改内部结构并将NodeParent 这两个表合并。原因是代码中的反向传播需要 O(depth_from_root * n),其中 n 是表中的节点数。这是因为 find() 将为每个父级遍历整个表。

相反,您可以实现类似于

table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value

function insert(data,parent_ID)
    if end_pointer < numel(table)
        content.data = data;
        content.parent = parent_ID;
        table{end_pointer} = content;
        end_pointer = end_pointer + 1;
    else
        % need more space, make sure its enough this time
        table = [table cell(end_pointer,1)];
        insert(data,parent_ID);
    end
end

function content = get_value(ID)
    content = table(ID);
end

这立即使您可以访问父级的ID,而无需首先访问find(),每一步都节省了n 次迭代,因此负担变为O(深度)。如果你不知道你的初始节点,那么你必须find()那个,这需要花费 O(n)。

请注意,此结构不需要is_leaf()depth()nnodes()get_children()。如果您仍然需要这些,我需要更深入地了解您希望如何处理您的数据,因为这会极大地影响正确的结构。


树形结构

如果您永远不知道第一个节点的 ID 并且因此总是必须搜索它,那么这种结构是有意义的。

好处是搜索任意音符的时间为 O(depth),因此搜索是 O(depth) 而不是 O(n),反向传播是 O(depth^2) 而不是 O(depth + n )。请注意,深度可以是任何值,从用于完美平衡树的 log(n)(取决于您的数据)到用于退化树的 n(它只是一个链表)。

但是,我需要更多的洞察力来提出正确的建议,因为每种树结构都有其自己的利基。从我目前所看到的情况来看,我建议使用不平衡的树,它按照想要的父节点给出的简单顺序“排序”。这可能会进一步优化,具体取决于

  • 是否可以根据您的数据定义总顺序
  • 如何处理双精度值(相同的数据出现两次)
  • 您的数据规模有多大(数千、数百万……)
  • 总是与反向传播配对的查找/搜索
  • 您的数据上的“父子”链有多长(或者使用这个简单的顺序树的平衡和深度)
  • 总是只有一个父元素,还是同一元素被不同的父元素插入两次

我很乐意为上面的树提供示例代码,请给我留言。

编辑: 在您的情况下,不平衡树(与执行 MCTS 平行)似乎是最佳选择。下面的代码假定数据在statescore 中拆分,并且state 是唯一的。如果不是这样,这仍然有效,但是可能会进行优化以提高 MCTS 性能。

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update the this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % there are to many childs, we have to expand the table
                obj.childs = [obj.childs cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                %check if state has already been visited
                terminate = false;
                idx = 1;
                while idx <= obj.num_childs && ~terminate
                    if obj.childs{idx}.state_equals(state_to_explore)
                        terminate = true;
                    end
                    idx = idx + 1;
                end

                %preform the according action based on search
                if idx > obj.num_childs
                    % state has never been visited
                    % this action terminates the update recursion 
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end

代码上的几个cmets:

  • 根据设置,可能不需要obj.calculate_value()。例如。如果它是一个可以通过单独评估孩子的分数来计算的值
  • 如果 state 可以有多个父对象,那么重用 note 对象并将其覆盖在结构中是有意义的
  • 由于每个node 都知道其所有子节点,因此可以使用node 作为根节点轻松生成子树
  • 搜索树(没有任何更新)是一个简单的递归贪心搜索
  • 根据搜索的分支因素,可能值得访问每个可能的子节点一次(在节点初始化时),然后再进行randsample(obj.childs,1) 进行探索,因为这样可以避免复制/重新分配子数组
  • parent 属性在树递归更新时进行编码,在完成节点更新后将 value 传递给父节点
  • 我唯一一次重新分配内存是当单个节点有超过 50 个子节点时,我只为该单个节点重新分配

这应该运行得更快,因为它只担心选择树的任何部分而不涉及其他任何部分。

【讨论】:

  • 感谢您的回复!我的应用程序是一种蒙特卡洛树搜索。我最初并没有意识到它会对设计树产生多大影响,对此感到抱歉!我更新了问题以包含更多详细信息。我希望它可以帮助您更好地集中答案。
  • 另一个后续:对于特定的应用程序和我遍历树的方式,您认为查找表实现实际上更合适吗?我以前没有考虑过。我在 MCTS 上阅读的树搜索方面、可视化和其他文档吸引我直接使用树数据结构来实现它。
  • 我是否正确假设您的“数据”由一些奇特的状态+该状态的分值组成?而且更新不会影响状态,而是会改变分数?
  • 没错。数据首先包含一个标识符(通常是整数),它将节点链接到在底层优化任务中有意义的东西,但也需要知道 history,即节点的标识符通过达到了这个节点,总体来说是有意义的。此外,还有一个分数(浮点数)和一个访问次数(正整数)。更新仅更改分数和计数。
  • 要继续查找表的想法,另一种方法可能是为每个节点保存此历史记录,从而无需历史记录。
猜你喜欢
  • 1970-01-01
  • 2017-09-17
  • 2011-01-06
  • 2013-12-17
  • 2017-03-21
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2013-03-08
相关资源
最近更新 更多