K-D Tree
k-D Tree(KDT , k-Dimension Tree) 是一种可以 高效处理
在结点数
在算法竞赛的题目中,一般有
建树¶
k-D Tree 具有二叉搜索树的形态,二叉搜索树上的每个结点都对应
假设我们已经知道了
-
若当前超长方体中只有一个点,返回这个点。
-
选择一个维度,将当前超长方体按照这个维度分成两个超长方体。
-
选择切割点:在选择的维度上选择一个点,这一维度上的值小于这个点的归入一个超长方体(左子树),其余的归入另一个超长方体(右子树)。
-
将选择的点作为这棵子树的根节点,递归对分出的两个超长方体构建左右子树,维护子树的信息。
为了方便理解,我们举一个
其构建出 k-D Tree 的形态可能是这样的:
其中树上每个结点上的坐标是选择的分割点的坐标,非叶子结点旁的
这样的复杂度无法保证。对于
-
选择的维度要满足其内部点的分布的差异度最大,即每次选择的切割维度是方差最大的维度。
-
每次在维度上选择切割点时选择该维度上的 中位数,这样可以保证每次分成的左右子树大小尽量相等。
可以发现,使用优化
现在,构建 k-D Tree 时间复杂度的瓶颈在于快速选出一个维度上的中位数,并将在该维度上的值小于该中位数的置于中位数的左边,其余置于右边。如果每次都使用 sort
函数对该维度进行排序,时间复杂度是
我们来回顾一下快速排序的思想。每次我们选出一个数,将小于该数的置于该数的左边,大于该数的置于该数的右边,保证该数在排好序后正确的位置上,然后递归排序左侧和右侧的值。这样的期望复杂度是 algorithm
库中,有一个实现相同功能的函数 nth_element()
,要找到 s[l]
和 s[r]
之间的值按照排序规则 cmp
排序后在 s[mid]
位置上的值,并保证 s[mid]
左边的值小于 s[mid]
,右边的值大于 s[mid]
,只需写 nth_element(s+l,s+mid,s+r+1,cmp)
。
借助这种思想,构建 k-D Tree 时间复杂度是
插入/删除¶
如果维护的这个
我们引入一个重构常数
在插入一个
如果还有删除操作,则使用 惰性删除,即删除一个结点时打上删除标记,而保留其在 k-D Tree 上的位置。如果这样写,当未删除的结点数在以
类似于替罪羊树,带重构的 k-D Tree 的树高仍然是
邻域查询¶
例题luogu P1429 平面最近点对(加强版)
给定平面上的
首先建出关于这
枚举每个结点,对于每个结点找到不等于该结点且距离最小的点,即可求出答案。每次暴力遍历 2-D Tree 上的每个结点的时间复杂度是
此外,还可以使用一种启发式搜索的方法,即若一个结点的两个子树都有可能包含答案,先在与查询点距离最近的一个子树中搜索答案。可以认为,查询点到子树对应的长方形的最近距离就是此题的估价函数。
注意:虽然以上使用的种种优化,但是使用 k-D Tree 单次查询最近点的时间复杂度最坏还是
参考代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | #include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
const int maxn = 200010;
int n, d[maxn], lc[maxn], rc[maxn];
double ans = 2e18;
struct node {
double x, y;
} s[maxn];
double L[maxn], R[maxn], D[maxn], U[maxn];
double dist(int a, int b) {
return (s[a].x - s[b].x) * (s[a].x - s[b].x) +
(s[a].y - s[b].y) * (s[a].y - s[b].y);
}
bool cmp1(node a, node b) { return a.x < b.x; }
bool cmp2(node a, node b) { return a.y < b.y; }
void maintain(int x) {
L[x] = R[x] = s[x].x;
D[x] = U[x] = s[x].y;
if (lc[x])
L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
if (rc[x])
L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}
int build(int l, int r) {
if (l >= r) return 0;
int mid = (l + r) >> 1;
double avx = 0, avy = 0, vax = 0, vay = 0; // average variance
for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
avx /= (double)(r - l + 1);
avy /= (double)(r - l + 1);
for (int i = l; i <= r; i++)
vax += (s[i].x - avx) * (s[i].x - avx),
vay += (s[i].y - avy) * (s[i].y - avy);
if (vax >= vay)
d[mid] = 1, nth_element(s + l, s + mid, s + r + 1, cmp1);
else
d[mid] = 2, nth_element(s + l, s + mid, s + r + 1, cmp2);
lc[mid] = build(l, mid - 1), rc[mid] = build(mid + 1, r);
maintain(mid);
return mid;
}
double f(int a, int b) {
double ret = 0;
if (L[b] > s[a].x) ret += (L[b] - s[a].x) * (L[b] - s[a].x);
if (R[b] < s[a].x) ret += (s[a].x - R[b]) * (s[a].x - R[b]);
if (D[b] > s[a].y) ret += (D[b] - s[a].y) * (D[b] - s[a].y);
if (U[b] < s[a].y) ret += (s[a].y - U[b]) * (s[a].y - U[b]);
return ret;
}
void query(int l, int r, int x) {
if (l > r) return;
int mid = (l + r) >> 1;
if (mid != x) ans = min(ans, dist(x, mid));
if (l == r) return;
double distl = f(x, lc[mid]), distr = f(x, rc[mid]);
if (distl < ans && distr < ans) {
if (distl < distr) {
query(l, mid - 1, x);
if (distr < ans) query(mid + 1, r, x);
} else {
query(mid + 1, r, x);
if (distl < ans) query(l, mid - 1, x);
}
} else {
if (distl < ans) query(l, mid - 1, x);
if (distr < ans) query(mid + 1, r, x);
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%lf%lf", &s[i].x, &s[i].y);
build(1, n);
for (int i = 1; i <= n; i++) query(1, n, i);
printf("%.4lf\n", sqrt(ans));
return 0;
}
|
例题「CQOI2016」K 远点对
给定平面上的
和上一道例题类似,从最近点对变成了
由于题目中强调的是无序点对,即交换前后两点的顺序后仍是相同的点对,则每个有序点对会被计算两次,那么读入的
参考代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | #include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
const int maxn = 100010;
long long n, k;
priority_queue<long long, vector<long long>, greater<long long> > q;
struct node {
long long x, y;
} s[maxn];
bool cmp1(node a, node b) { return a.x < b.x; }
bool cmp2(node a, node b) { return a.y < b.y; }
long long lc[maxn], rc[maxn], L[maxn], R[maxn], D[maxn], U[maxn];
void maintain(int x) {
L[x] = R[x] = s[x].x;
D[x] = U[x] = s[x].y;
if (lc[x])
L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
if (rc[x])
L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}
int build(int l, int r) {
if (l > r) return 0;
int mid = (l + r) >> 1;
double av1 = 0, av2 = 0, va1 = 0, va2 = 0; // average variance
for (int i = l; i <= r; i++) av1 += s[i].x, av2 += s[i].y;
av1 /= (r - l + 1);
av2 /= (r - l + 1);
for (int i = l; i <= r; i++)
va1 += (av1 - s[i].x) * (av1 - s[i].x),
va2 += (av2 - s[i].y) * (av2 - s[i].y);
if (va1 > va2)
nth_element(s + l, s + mid, s + r + 1, cmp1);
else
nth_element(s + l, s + mid, s + r + 1, cmp2);
lc[mid] = build(l, mid - 1);
rc[mid] = build(mid + 1, r);
maintain(mid);
return mid;
}
long long sq(long long x) { return x * x; }
long long dist(int a, int b) {
return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) +
max(sq(s[a].y - D[b]), sq(s[a].y - U[b]));
}
void query(int l, int r, int x) {
if (l > r) return;
int mid = (l + r) >> 1;
long long t = sq(s[mid].x - s[x].x) + sq(s[mid].y - s[x].y);
if (t > q.top()) q.pop(), q.push(t);
long long distl = dist(x, lc[mid]), distr = dist(x, rc[mid]);
if (distl > q.top() && distr > q.top()) {
if (distl > distr) {
query(l, mid - 1, x);
if (distr > q.top()) query(mid + 1, r, x);
} else {
query(mid + 1, r, x);
if (distl > q.top()) query(l, mid - 1, x);
}
} else {
if (distl > q.top()) query(l, mid - 1, x);
if (distr > q.top()) query(mid + 1, r, x);
}
}
int main() {
cin >> n >> k;
k *= 2;
for (int i = 1; i <= k; i++) q.push(0);
for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
build(1, n);
for (int i = 1; i <= n; i++) query(1, n, i);
cout << q.top() << endl;
return 0;
}
|
高维空间上的操作¶
例题luogu P4148 简单题
在一个初始值全为
1 x y A
:将坐标(x,y) A 2 x1 y1 x2 y2
:输出以(x1,y1) (x2,y2)
强制在线。内存限制 20M
。保证答案及所有过程量在 int
范围内。
20M 的空间卡掉了所有树套树,强制在线卡掉了 CDQ 分治,只能使用 k-D Tree。
构建 2-D Tree,支持两种操作:添加一个
在查询矩形区域内的所有点的权值和时,仍然需要记录子树内每一维度上的坐标的最大值和最小值。如果当前子树对应的矩形与所求矩形没有交点,则不继续搜索其子树;如果当前子树对应的矩形完全包含在所求矩形内,返回当前子树内所有点的权值和;否则,判断当前点是否在所求矩形内,更新答案并递归在左右子树中查找答案。
已经证明,如果在
参考代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | #include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 200010;
int n, op, xl, xr, yl, yr, lstans;
struct node {
int x, y, v;
} s[maxn];
bool cmp1(int a, int b) { return s[a].x < s[b].x; }
bool cmp2(int a, int b) { return s[a].y < s[b].y; }
double a = 0.725;
int rt, cur, d[maxn], lc[maxn], rc[maxn], L[maxn], R[maxn], D[maxn], U[maxn],
siz[maxn], sum[maxn];
int g[maxn], t;
void print(int x) {
if (!x) return;
print(lc[x]);
g[++t] = x;
print(rc[x]);
}
void maintain(int x) {
siz[x] = siz[lc[x]] + siz[rc[x]] + 1;
sum[x] = sum[lc[x]] + sum[rc[x]] + s[x].v;
L[x] = R[x] = s[x].x;
D[x] = U[x] = s[x].y;
if (lc[x])
L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
if (rc[x])
L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}
int build(int l, int r) {
if (l > r) return 0;
int mid = (l + r) >> 1;
double av1 = 0, av2 = 0, va1 = 0, va2 = 0;
for (int i = l; i <= r; i++) av1 += s[g[i]].x, av2 += s[g[i]].y;
av1 /= (r - l + 1);
av2 /= (r - l + 1);
for (int i = l; i <= r; i++)
va1 += (av1 - s[g[i]].x) * (av1 - s[g[i]].x),
va2 += (av2 - s[g[i]].y) * (av2 - s[g[i]].y);
if (va1 > va2)
nth_element(g + l, g + mid, g + r + 1, cmp1), d[g[mid]] = 1;
else
nth_element(g + l, g + mid, g + r + 1, cmp2), d[g[mid]] = 2;
lc[g[mid]] = build(l, mid - 1);
rc[g[mid]] = build(mid + 1, r);
maintain(g[mid]);
return g[mid];
}
void rebuild(int& x) {
t = 0;
print(x);
x = build(1, t);
}
bool bad(int x) { return a * siz[x] <= (double)max(siz[lc[x]], siz[rc[x]]); }
void insert(int& x, int v) {
if (!x) {
x = v;
maintain(x);
return;
}
if (d[x] == 1) {
if (s[v].x <= s[x].x)
insert(lc[x], v);
else
insert(rc[x], v);
} else {
if (s[v].y <= s[x].y)
insert(lc[x], v);
else
insert(rc[x], v);
}
maintain(x);
if (bad(x)) rebuild(x);
}
int query(int x) {
if (!x || xr < L[x] || xl > R[x] || yr < D[x] || yl > U[x]) return 0;
if (xl <= L[x] && R[x] <= xr && yl <= D[x] && U[x] <= yr) return sum[x];
int ret = 0;
if (xl <= s[x].x && s[x].x <= xr && yl <= s[x].y && s[x].y <= yr)
ret += s[x].v;
return query(lc[x]) + query(rc[x]) + ret;
}
int main() {
scanf("%d", &n);
while (~scanf("%d", &op)) {
if (op == 1) {
cur++, scanf("%d%d%d", &s[cur].x, &s[cur].y, &s[cur].v);
s[cur].x ^= lstans;
s[cur].y ^= lstans;
s[cur].v ^= lstans;
insert(rt, cur);
}
if (op == 2) {
scanf("%d%d%d%d", &xl, &yl, &xr, &yr);
xl ^= lstans;
yl ^= lstans;
xr ^= lstans;
yr ^= lstans;
printf("%d\n", lstans = query(rt));
}
if (op == 3) return 0;
}
}
|
习题¶
build本页面最近更新:,更新历史
edit发现错误?想一起完善? 在 GitHub 上编辑此页!
people本页面贡献者:hsfzLZH1, Ir1d
copyright本页面的全部内容在 CC BY-SA 4.0 和 SATA 协议之条款下提供,附加条款亦可能应用