P3794 签到题IV

题目: luogu 签到题IV
话说这道题和UVA1642 魔法GCD很类似 # Solution

题目要求找所有满足条件的子序列,可以考虑固定右端点\(j\),然后快速求出每个左端点\(i\)对应的子序列的有关信息。
好像有点抽象,可以想一下这个序列怎么做。
\[3,4,6,4,2\] \(j=4\)时,我们写一下所有子序列的有关值,即\(gcd(a-i,a_{i+1}...a-j)\)\((a-i\ or\ a_{i+1}...\ or\ a-j)\) 1. \(i=1\)时,\(gcd(a-1,a-2,a-3,a-4)=1\), \((a-1\ or\ a-2\ or\ a-3\ or\ a-4)=7\) 2. \(i=2\)时,\(gcd(a-2,a-3,a-4)=2\), \((a-2\ or\ a-3\ or\ a-4)=6\) 3. \(i=3\)时,\(gcd(a-3,a-4)=2\), \((a-3\ or\ a-4)=6\) 4. \(i=4\)时,\(gcd(a-4)=4\), \((a-4 \ or\ a-4)=4\)

很显然我们可以把每个左端点\(i\)对应的有关信息存下来,然后一个个判断是否符合条件。每次\(j\)右移时(即增加一个元素)都把每个左端点维护的信息更新。这样做是\(O(n^2 logn)\)的。
观察上表,可以发现有的\(gcd\)或者二进制或的和是相等的。事实上每加入一个新的数,序列的\(gcd\)是不变或者变小的,且变小时会变为\(a-j\)的约数。而序列的二进制或则是不变或者变大的,如果把数写成二进制的形式,可以发现如果增大只会在某一位增加\(1\),而一共只有\(\log-2 a-j\)位。所以\(gcd\)或者二进制或的值的个数都是\(log\)级别的。
那么如果把相等的值合并,复杂度可以降为\(O(nlog^2 n)\)的。
因为\(a\oplus b=k\),有\(a\oplus k=b\)。那么可以分别维护\(gcd(a-i,a_{i+1}...a-j)\)\((a-i\ or\ a_{i+1}...\ or\ a-j)\)。计算出每个\(gcd\oplus k\),之后去找有没有等于\(gcd\oplus k\)\(or\)值。
怎么维护呢?可以发现无论是\(gcd\)还是\(or\)的值都是单调的,受到\(ODT\)的启发,可以把相同的值合并为一个块,也就是记录它的左右端点和值,每次查询都\(O(log)\)的算出每个\(gcd\oplus k\)。然后二分查找或者直接枚举有没有与它相等的\(or\)值,如果有,就把两个区间(该\(gcd\)对应的区间和\(or\)对应的区间)取交集,答案累加上此时的元素个数就可以。时间复杂度\(O(n log^2 n)\),实际上很难有每次都使表里有\(\log a-j\)个元素的数据,所以复杂度和\(O(nlog n)\)接近。

不知道为什么二分还没直接枚举跑得快QwQ

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5 + 10;
int n, k, a[N];
struct node {
int l, r, v;
};
vector<node> v, o;

int read() {
int w = 0; char ch = getchar();
while(ch > '9' || ch < '0') ch = getchar();
while(ch >= '0' && ch <= '9') {
w = w * 10 + ch - 48;
ch = getchar();
}
return w;
}
int gcd(int a, int b) {
return b == 0 ? a : gcd(b, a % b);
}
node bound(int x) {
for(int i = 0; i < o.size(); i++) {
if(o[i].v == x) return {o[i].l, o[i].r, x};
}
return {0, 0, -1};
}

int main() {
long long sum = 0;
n = read(); k = read();
for(int i = 1; i <= n; i++) a[i] = read();
for(int j = 1; j <= n; j++) {
v.push-back({j, j, a[j]});
o.push-back({j, j, a[j]});
for(int i = v.size() - 2; i >= 0; i--) {
v[i].v = gcd(v[i].v, a[j]);
if(v[i].v == v[i + 1].v) {
v[i].r = v[i + 1].r;
v.erase(v.begin() + i + 1);
}
}
for(int i = o.size() - 2; i >= 0; i--) {
o[i].v |= a[j];
if(o[i].v == o[i + 1].v) {
o[i].r = o[i + 1].r;
o.erase(o.begin() + i + 1);
}
}
for(int i = 0; i < v.size(); i++) {
node ans = bound(v[i].v ^ k);
if(ans.v == -1 || ans.l > v[i].r || ans.r < v[i].l) continue;
ans.l = max(v[i].l, ans.l); ans.r = min(ans.r, v[i].r);
sum += ans.r - ans.l + 1;
}
}
printf("%lld\n",sum);
return 0;
}