树上差分

DFS 序 3,树上差分 1

给一棵有根树,这棵树由编号为\(1...N\)\(N\)个结点组成。根结点的编号为\(R\) 。每个结点都有一个权值,结点\(i\) 的权值为\(V-i\) 。 接下来有\(M\)组操作,操作分为三类:

  • 1 a b x,表示将「结点\(a\)到结点\(b\)的简单路径」上所有结点的权值都增加\(x\)
  • 2 a,表示求结点\(a\) 的权值。
  • 3 a,表示求 \(a\)的子树上所有结点的权值之和。

上来就写了树剖,交上去\(TLE\)了一个点。

诶,我被卡常数了?

卡常数卡到怀疑人生,点开最优解,树上差分啊,那没事了。

分析

先说前两个操作,考虑更改的链两个端点\(u,v\)祖先关系的情况。

\(v\)增加\(x\)\(fa[u]\)减去\(x\)即可。单点查询时查询该点的子树差分数组和。

更一般的情况是,\(u\)\(v\)增加\(x\)\(lca\)\(fa[lca]\)减去\(x\)

意会一下这样修改显然是对的。

对于三操作我们想一下\(u\)子树被修改的\(v\)的贡献。

每个点的贡献是\((dep[v]-dep[u]+1)\times val[v]\)

也就是\(dep[v]\times val[v]-(dep[u]-1)\times val[v]\)

所以开一个树状数组记一下\(dep[v]\times val[v]\),而\(\sum val[v]\)直接查一下所有子树修改的和就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int N = 1e6 + 5;

int n, m, r, v[N];
struct Bit {
ll t1[N], t2[N];
int lowbit(int x) { return x & -x; }
void add(int x, ll k) {
for (int i = x; i <= n; i += lowbit(i)) t1[i] += k, t2[i] += k * 1ll * (x - 1);
}
void upd(int l, int r, ll k) {
add(l, k);
add(r + 1, -k);
}
ll sum(int x) {
ll ans = 0;
for (int i = x; i; i -= lowbit(i)) ans += x * 1ll * t1[i] - t2[i];
return ans;
}
ll qry(int l, int r) { return sum(r) - sum(l - 1); }
} b1, b2;

vector<int> e[N];

int siz[N], son[N], top[N], fa[N], dfn[N], rnk[N], dep[N], dfs-clock;
void dfs1(int u, int f) {
dep[u] = dep[f] + 1;
siz[u] = 1;
fa[u] = f;
for (int v : e[u]) {
if (v == f) continue;
dfs1(v, u);
siz[u] += siz[v];
if (son[u] == 0 || siz[son[u]] < siz[v]) son[u] = v;
}
}
void dfs2(int u, int tp) {
dfn[u] = ++dfs-clock, rnk[dfs-clock] = u, top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (int v : e[u]) {
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
return v;
}
int read() {
int w = 0, f = 1;
char ch = getchar();
while (ch > '9' || ch < '0') {
if (ch == '-') f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
w = w * 10 + ch - 48;
ch = getchar();
}
return w * f;
}

int main() {
n = read(), m = read(), r = read();
for (int i = 1; i <= n; i++) v[i] = read();
for (int i = 1; i < n; i++) {
int u, v;
u = read(), v = read();
e[u].push-back(v);
e[v].push-back(u);
}
dfs1(r, 0), dfs2(r, r);
for (int i = 1; i <= n; i++) {
b1.upd(dfn[i], dfn[i], v[i]);
b2.upd(dfn[i], dfn[i], v[i] * 1ll * dep[i]);
if (i != r) b1.upd(dfn[fa[i]], dfn[fa[i]], -v[i]), b2.upd(dfn[fa[i]], dfn[fa[i]], -dep[fa[i]] * 1ll * v[i]);
}
for (int i = 1; i <= m; i++) {
int op, a, b, x;
op = read(), a =read();
if (op == 1) {
b = read(), x = read();
int LCA = lca(a, b);
b1.upd(dfn[b], dfn[b], x);
b1.upd(dfn[a], dfn[a], x);
b1.upd(dfn[LCA], dfn[LCA], -x);
b2.upd(dfn[b], dfn[b], dep[b] * 1ll * x);
b2.upd(dfn[a], dfn[a], dep[a] * 1ll * x);
b2.upd(dfn[LCA], dfn[LCA], -dep[LCA] * 1ll * x);
if (LCA != r) b1.upd(dfn[fa[LCA]], dfn[fa[LCA]], -x), b2.upd(dfn[fa[LCA]], dfn[fa[LCA]], -dep[fa[LCA]] * 1ll * x);
} else if (op == 2) printf("%lld\n", b1.qry(dfn[a], dfn[a] + siz[a] - 1));
else {
ll res = -b1.qry(dfn[a], dfn[a] + siz[a] - 1) * 1ll * (dep[a] - 1);
res += b2.qry(dfn[a], dfn[a] + siz[a] - 1);
printf("%lld\n", res);
}
}
return 0;
}

DFS 序 4

给一棵有根树,这棵树由编号为\(1...N\)\(N\)个结点组成。根结点的编号为\(R\) 。每个结点都有一个权值,结点\(i\) 的权值为\(V-i\) 。 接下来有\(M\)组操作,操作分为三类:

  • 1 a x,表示将结点\(a\) 的权值增加 ;
  • 2 a x,表示将\(a\) 的子树上所有结点的权值增加\(x\)
  • 3 a b,表示求「结点\(a\) 到结点 \(b\)的简单路径」上所有结点的权值之和。

一个重要的转化是\(val(u,v)=val(u)+val(v)-val(lca)-val(fa[lca])\)

考虑\(1,3\)问题,每个单点修改会贡献到子树上,单点修改直接修改,查询时分别查\(a,b\)到根的权值,减去\(lca\)\(fa[lca]\)的值即可。

对于\(2,3\)问题,\(u\)对子树节点\(v\)的贡献是\((dep[v]-dep[u]+1)\times val=dep[v]\times val+(1-dep[u])\times val\)

先把固定的\((1-dep[u])\times val\)给子树每个节点,另外子树的点记一下被加了多少次,查询的时候就乘上一个\(dep\),于是就知道了节点\(v\)到根的权值和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int N = 1e6 + 5;

int n, m, r;
struct Bit {
ll t1[N], t2[N];
int lowbit(int x) { return x & -x; }
void add(int x, ll k) {
for (int i = x; i <= n; i += lowbit(i)) t1[i] += k, t2[i] += k * 1ll * (x - 1);
}
void upd(int l, int r, ll k) {
add(l, k);
add(r + 1, -k);
}
ll sum(int x) {
ll ans = 0;
for (int i = x; i; i -= lowbit(i)) ans += x * 1ll * t1[i] - t2[i];
return ans;
}
ll qry(int l, int r) { return sum(r) - sum(l - 1); }
} b1, b2;

ll v[N];
vector<int> e[N];

int siz[N], son[N], top[N], fa[N], dfn[N], rnk[N], dep[N], dfs-clock;
void dfs1(int u, int f) {
dep[u] = dep[f] + 1;
siz[u] = 1;
fa[u] = f;
for (int v : e[u]) {
if (v == f) continue;
dfs1(v, u);
siz[u] += siz[v];
if (son[u] == 0 || siz[son[u]] < siz[v]) son[u] = v;
}
}
void dfs2(int u, int tp) {
dfn[u] = ++dfs-clock, rnk[dfs-clock] = u, top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (int v : e[u]) {
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
return v;
}
ll calc(int x) {
if (!x) return 0;
return b1.qry(dfn[x], dfn[x]) + b2.qry(dfn[x], dfn[x]) * 1ll * dep[x];
}

int main() {
scanf("%d %d %d", &n, &m, &r);
for (int i = 1; i <= n; i++) scanf("%lld", &v[i]);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
e[u].push-back(v);
e[v].push-back(u);
}
dfs1(r, 0), dfs2(r, r);
for (int i = 1; i <= n; i++) b1.upd(dfn[i], dfn[i] + siz[i] - 1, v[i]);
while (m--) {
int op, a, x;
scanf("%d %d %d", &op, &a, &x);
if (op == 1) {
b1.upd(dfn[a], dfn[a] + siz[a] - 1, x);
} else if (op == 2) {
b1.upd(dfn[a], dfn[a] + siz[a] - 1, (1 - dep[a]) * 1ll * x);
b2.upd(dfn[a], dfn[a] + siz[a] - 1, x);
} else if (op == 3) {
int LCA = lca(a, x);
printf("%lld\n", calc(a) + calc(x) - calc(LCA) - calc(fa[LCA]));
}
}
return 0;
}