LOJ2958「COCI 2009.10」ALADIN

类欧几里得算法例题

上午学了类欧,来做道例题练习一下.

题目

懒得说了,链接在上面...

题解

根据题意,每次做\(1\)操作的时候都是区间覆盖,所以原来的值和新的值没有关系.
假设现在覆盖了区间\(l,r\),怎么快速求出区间和呢?
\(n=l-r+1\),
即求
\(\ \ \ \ \sum\limits_{i=0}^ni\cdot A\bmod B\)
\(=\sum\limits_{i=0}^{n}i\cdot A-\left\lfloor\frac{i\cdot A}{B}\right\rfloor \cdot B\)
\(=A\cdot\sum\limits_{i=0}^ni-B\cdot\sum\limits_{i=0}^n\left\lfloor\frac{i\cdot A}{B}\right\rfloor\)

观察发现,左边就是一个等差数列,考虑右边怎么做.
也就是这个式子 $ f(a,b,c,n)={i=0}^n \(, 只不过是\)b=0$的情况.
$      f(a,b,c,n)=
{i=0}^n  $
\(=\sum\limits_{i=0}^n\left\lfloor \frac{\left(\left\lfloor\frac{a}{c}\right\rfloor c+a\bmod c\right)i+\left(\left\lfloor\frac{b}{c}\right\rfloor c+b\bmod c\right)}{c}\right\rfloor\)
\(=\frac{n(n+1)}{2}\left\lfloor\frac{a}{c}\right\rfloor+(n+1)\left\lfloor\frac{b}{c}\right\rfloor+ \sum\limits_{i=0}^n\left\lfloor\frac{\left(a\bmod c\right)i+\left(b\bmod c\right)}{c} \right\rfloor\) \(=\frac{n(n+1)}{2}\left\lfloor\frac{a}{c}\right\rfloor +(n+1)\left\lfloor\frac{b}{c}\right\rfloor+f(a\bmod c,b\bmod c,c,n)\)

现在只用考虑\(a<c,b<c\)的情况.
把式子变化一下, 因为想消去\(i\),所以增加一个变量, \(f(a,b,c,n)=\sum\limits_{i=0}^{n} \sum\limits_{j=0}^{\left\lfloor\frac{ai+b}{c} \right\rfloor -1} 1\)
交换求和符号.
\(\sum\limits_{j=0}^{\left\lfloor \frac{an+b}{c} \right\rfloor-1} \sum\limits_{i=0}^n\left[j<\left\lfloor \frac{ai+b}{c} \right\rfloor\right]\)

因为\(j<\left\lfloor \frac{ai+b}{c} \right\rfloor\Leftrightarrow j+1\leq \left\lfloor \frac{ai+b}{c} \right\rfloor\Leftrightarrow j+1\leq \frac{ai+b}{c}\Leftrightarrow cj+c\leq ai+b\Leftrightarrow \left\lfloor\frac{cj+c-b-1}{a}\right\rfloor<i\)

所以
\(\ \ \sum\limits_{j=0}^{\left\lfloor \frac{an+b}{c} \right\rfloor-1} \sum\limits_{i=0}^n\left[\left\lfloor\frac{cj+c-b-1}{a}\right\rfloor<i\right]\)
\(=\sum\limits_{j=0}^{\left\lfloor \frac{an+b}{c} \right\rfloor-1}n-\left\lfloor\frac{cj+c-b-1}{a}\right\rfloor\)
\(=\left\lfloor \frac{an+b}{c} \right\rfloor\times n-f(c,c-b-1,a,\left\lfloor \frac{an+b}{c} \right\rfloor-1)\)
发现\(a,c\)也恰好是交换的,是可以像欧几里得算法那样递归处理.

区间和能求出来了,但是询问的区间不一定就是修改的区间啊!
假设现在要查询\(l,r\)的和,我们只用知道其中的子区间是哪次覆盖产生的就可以了.
我们记录子区间的是第几次覆盖产生的,就可以知道这段子区间的值.
具体就是把覆盖的那次的整个区间的值用上面方法算出来,减去除开这段子区间的就可以了.
每次计算的时间复杂度都是\(\mathjaxcal{O}(logn)\)的.
问题来了,怎么知道到底有哪些子区间呢?
我的想法是用珂朵莉树来模拟每次的\(1\)操作,最后每个块就恰好就是所谓的子区间了.
每个块只需要记录一下左右端点和所对应的操作编号就行了.
当然也可以用线段树来做,其他题解说得很清楚了.

代码

话说是不是写得很丑啊...

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//CCCCCCCCCCCCCCCCCCCCCCCCOrz
#include<bits/stdc++.h>
using namespace std;

#define int --int128
const int Q = 50005;
//ODT
struct node {
int l, r;
mutable int id;
node(int L, int R = -1, int ID = -1) : l(L), r(R), id(ID) {}
inline bool operator < (const node &x) const {
return l < x.l;
}
};
struct data {
int op, l, r, a, b;
}a[Q];

#define IT set<node>::iterator
set<node> s;

IT split(int pos) {
IT it = s.lower-bound(node(pos));
if(it -> l == pos && it != s.end()) return it;
--it;
int v = it->id, l = it->l, r = it->r;
s.erase(it); s.insert(node(l, pos - 1, v));
return s.insert(node(pos, r, v)).first;
}
void assign(int l, int r, int id) {
IT itr = split(r + 1), itl = split(l);
s.erase(itl, itr);
s.insert(node(l, r, id));
}

int read() {
int w = 0, f = 1; char ch = getchar();
while(ch > '9' || ch < '0') {
if (ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
w = w * 10 + ch - 48;
ch = getchar();
}
return w * f;
}
void print(int x) {
if (x < 0) {
putchar('-');
x = -x;
}
if (x > 9) print(x / 10);
putchar(x % 10 + 48);
}

int calc(int n, int a, int b, int c) {
int ac = a / c, bc = b / c, m = (a * n + b) / c, n1 = n + 1, n2 = n * 2 + 1, d;
if (a == 0) {
d = bc * n1;
return d;
}
if (a >= c || b >= c) {
d = n * n1 / 2 * ac + bc * n1;
int e = calc(n, a % c, b % c, c);
d += e;
return d;
}
int e = calc(m - 1, c, c - b - 1, a);
d = n * m - e;
return d;
}
int val(int n, int a, int b) {
return (n + 1) * n / 2 * a - b * calc(n - 1, a, a, b);
}
IT find(int pos) {
IT it = s.lower-bound(node(pos));
if(it -> l == pos && it != s.end()) return it;
--it;
return it;
}
int sum(int l, int r, int id) {
if (!id) return 0;
int L = a[id].l;
return val(r - L + 1, a[id].a, a[id].b) - val(l - L, a[id].a, a[id].b);
}
void solve(int l, int r) {
int ans = 0;
IT itl = find(l), itr = find(r);
if (itl == itr) {
ans = sum(l, r, itl->id);
print(ans);puts("");
return;
}
ans += sum(l, itl->r, itl->id);
ans += sum(itr->l, r, itr->id);
itl++;
if (itl == itr) {
print(ans); puts("");
return;
}
itr--;
for (IT it = itl; ; ++it) {
if (it->id != 0) ans += sum(it->l, it->r, it->id);
if (it == itr) break;
}
print(ans);puts("");
return;
}

main() {
int n = read(), q = read();
s.insert(node(1, n, 0));
s.insert(node(n + 1, n + 1, -1));
for (int i = 1; i <= q; i++) {
a[i].op = read(), a[i].l = read(), a[i].r = read();
if (a[i].op == 1) a[i].a = read(), a[i].b = read();
}
for (int i = 1; i <= q; i++) {
if (a[i].op == 1) {
assign(a[i].l, a[i].r, i);
} else {
solve(a[i].l, a[i].r);
}
}
return 0;
}