我终于弄明白了。
从条件
保证对于每个节点 E,左子树在该节点 E 处的节点数与右子树一样多或多一个。
这样
- 可以从树的深度计算出非叶子节点的个数;它是 2深度 - 1。因此,有趣的事情是叶节点。
- 考虑到平衡条件,总是只有一个地方可以插入新节点或移除现有节点。 (这意味着给定数量的叶节点意味着一个,并且只有一个叶节点模式。)
- 如果我们知道一个节点的左子树的叶子节点数,我们就知道右子树中的叶子节点数(和节点数)要么相同,要么比它少一。李>
- 从 2. 和 3. 可以得出,右子树中只有一个叶节点槽,如果不检查树是否被填满,我们就无法知道它。找到它是这个算法的诀窍。
所以,利用3:假设我们有一个(子)树T。我们知道它的左子树中的叶子节点数是nleft。因此我们知道它的右子树中的叶子节点的数量是 nleft 或 nleft - 1,特别是它最多是 nleft .
我们进入右子树。知道了这棵子树的最大叶子节点数,并且知道它们在两边的子树中平均分配,我们可以推断出两件事:
- 如果此子树中的最大叶节点数为奇数,则有问题的槽位于左侧,因为右侧不能比左侧重。如果是偶数,则插槽在右侧
- 每个子子树中叶节点的最大数量为子树中叶节点的一半,左侧向上舍入,右侧向下舍入。
这解决了问题的核心;剩下的就是简单的递归。在 C++ 中:
#include <cstddef>
// I'm using a simple node structure, you'd use query functions. The
// algorithm is not meaningfully altered by this.
struct node {
node *left = nullptr, *right = nullptr;
};
struct node_counter {
std::size_t leaf; // number of leaf nodes,
std::size_t trunk; // number of trunk nodes,
std::size_t depth; // and depth of the inspected subtree.
};
// Interesting function #1: Given a right subtree and the leaf-count and
// depth of its left sibling, find the node that might or might not be there
node const *find_leaf(node const *branch, std::size_t leaf_count, std::size_t depth) {
// We've gone down, found the slot. Return it.
if(depth == 0) { return branch; }
// The heart of the matter: Step into the subtree that contains the
// questionable slot, with its maximum leaf node count and depth.
return find_leaf(leaf_count % 2 ? branch->left : branch->right,
(leaf_count + 1) / 2, // int division
depth - 1);
}
// Recursive counter. This steps down on the left side, then infers the
// number of leaf and trunk nodes on the right side for each level.
node_counter count_nodes_aux(node const *root) {
// leftmost leaf node is reached. Return info for it.
if(!root->left) {
return { 1, 0, 0 };
}
// We're in the middle of the tree. Get the counts for the left side,
auto ctr_left = count_nodes_aux(root->left);
// then find the questionable slot on the right
auto leaf_right = find_leaf(root->right, ctr_left.leaf, ctr_left.depth);
return {
// the number of leaf nodes in this tree is double that of the left
// subtree if the node is there, one less otherwise.
ctr_left.leaf * 2 - (leaf_right ? 0 : 1),
// And this is just an easy way to keep count of the number of non-leaf
// nodes and the depth of the inspected subtree.
ctr_left.trunk * 2 + 1,
ctr_left.depth + 1
};
}
// Frontend function to make the whole thing easily usable.
std::size_t count_nodes(node const *root) {
auto ctr = count_nodes_aux(root);
return ctr.leaf + ctr.trunk;
}
为了尝试这一点,我使用了以下非常丑陋的main 函数,它只是构建一个包含许多节点的树,在正确的位置插入新节点并检查计数器是否以正确的方式移动。它不漂亮,它不遵循最佳实践,如果你在生产环境中编写这样的代码,你应该被解雇。就是这样,因为这个答案的重点是上面的算法,我觉得把这个弄漂亮一点意义都没有。
void fill_node(node *n) {
n->left = new node;
n->right = new node;
}
int main() {
node *root = new node;
fill_node(root);
fill_node(root->left);
fill_node(root->right);
fill_node(root->left->left);
fill_node(root->left->right);
fill_node(root->right->left);
fill_node(root->right->right);
fill_node(root->left->left->left);
fill_node(root->left->left->right);
fill_node(root->left->right->left);
fill_node(root->left->right->right);
fill_node(root->right->left->left);
fill_node(root->right->left->right);
fill_node(root->right->right->left);
fill_node(root->right->right->right);
std::cout << count_nodes(root) << std::endl;
root->left ->left ->left ->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->left ->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->left ->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->left ->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->right->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->right->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->right->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->right->left ->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->left ->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->left ->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->left ->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->left ->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->right->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->right->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->right->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->right->right->left = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->left ->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->left ->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->left ->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->left ->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->right->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->right->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->right->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->right->left ->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->left ->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->left ->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->left ->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->left ->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->left ->right->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->left ->right->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->left ->right->right->right->right = new node; std::cout << count_nodes(root) << std::endl;
root->right->right->right->right->right = new node; std::cout << count_nodes(root) << std::endl;
}