题目描述

[EN | CN]

给定一个二叉树(具有根结点 root), 一个目标结点 target,和一个整数值 K

返回到目标结点 target 距离为 K 的所有结点的值的列表。答案可以以任何顺序返回。

示例 1:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
3
   /   \
  5     1
 / \   / \
6   2 0   8
   / \
  7   4

输入:root = [3,5,1,6,2,0,8,null,null,7,4], target = 5, K = 2
输出:[7,4,1]
解释:
所求结点为与目标结点(值为 5)距离为 2 的结点,值分别为 7,4,以及 1。
注意:
输入的 "root" 和 "target" 实际上是树上的结点。上面的输入仅仅是对这些对象进行了序列化描述。

提示:

  • 给定的树是非空的,且最多有 K 个结点。
  • 树上的每个结点都具有唯一的值 0 <= node.val <= 500
  • 目标结点 target 是树上的结点。
  • 0 <= K <= 1000

解法 1:沿路径向 root 延伸

这个解法想法很简单,写起来有点麻烦。根据题意,满足的结点可能有以下几种情况:

  1. 位于 target 结点子树中;
  2. 位于从 roottarget 的这条路径(path)上;
  3. 位于 path 中的结点的邻居结点的子树中。

因此思路其实比较直接,就是遍历计算这三种情况。

复杂度分析略。

实现与结果如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, K: int) -> List[int]:
        # Find the path to target
        def path_to_root(root, target):
            path, is_left, curr = [], [], root
            while True:
                if curr:
                    path.append(curr)
                    if curr == target:
                        break
                    else:
                        is_left.append(True)
                        curr = curr.left
                else:
                    if is_left[-1]:
                        is_left[-1] = False
                        curr = path[-1].right
                    else:
                        path.pop()
                        is_left.pop()
            return path[::-1]

        # Subnodes with distance k to root
        def sub_distance_k(root, k):
            if k < 0:
                return []
            curr_level = [root]
            next_level = []
            while k > 0:
                for node in curr_level:
                    if node:
                        next_level.extend([node.left, node.right])
                curr_level, next_level = next_level, []
                k -= 1
            return [node for node in curr_level if node]

        path = path_to_root(root, target)

        # Nodes on the path
        ret = sub_distance_k(target, K)
        if K > 0 and len(path) > K:
            ret.append(path[K])

        # Nodes under the neighbors of nodes on the path
        for i in range(1, len(path)):
            if path[i - 1] == path[i].left:
                neighbor = path[i].right
            else:
                neighbor = path[i].left
            ret.extend(sub_distance_k(neighbor, K - i - 1))

        return [node.val for node in ret if node]
  • 执行用时:48 ms,在所有 Python3 提交中击败了 57.14% 的用户。
  • 内存消耗:13.5 MB,在所有 Python3 提交中击败了 100.00% 的用户。

解法 2:反向指针

有另一种做法更直接简单,实现起来也更方便。我们为每个结点记录其父结点(例如使用 DFS、BFS 等),从而获得一个有向图(父结点指向子结点,子结点指向父结点),随后我们就可以从 target 开始进行 BFS,从而获得离 target 距离为 K 的所有结点。

下面的实现用了一个 src_to_dst 来记录所有相邻的子结点或者父结点,所以不需要修改原数据。事实上,在 Python 中可以直接对原来的对象进行赋值(新的属性),因此可以更方便快速地实现,缺点就是会修改输入数据,因此这里还是不采用这种直接给输入对象赋值的做法。

复杂度分析略。

实现与结果如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

from collections import defaultdict

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, K: int) -> List[int]:
        src_to_dst = defaultdict(lambda: [])
        nodes = [root]
        while nodes:
            node = nodes.pop()
            if node.left:
                src_to_dst[node].append(node.left)
                src_to_dst[node.left].append(node)
                nodes.append(node.left)
            if node.right:
                src_to_dst[node].append(node.right)
                src_to_dst[node.right].append(node)
                nodes.append(node.right)

        curr_dist, next_dist, visited = [], [target], []
        for _ in range(K):
            curr_dist, next_dist = next_dist, []
            for node in curr_dist:
                visited.append(node)
                for dst in src_to_dst[node]:
                    if dst not in visited:
                        next_dist.append(dst)

        return [node.val for node in next_dist]
  • 执行用时:56 ms,在所有 Python3 提交中击败了 30.79% 的用户。
  • 内存消耗:13.8 MB,在所有 Python3 提交中击败了 100.00% 的用户。