题解:[SCOI2010] 序列操作

洛谷P2572

Posted by TH911 on December 21, 2024

题目传送门

题意分析

我们需要维护一种数据结构来维护 $0,1$ 区间的信息,使其能够高效的获取区间内 $1$ 的数量和最多的连续的 $1$ 的数量,并能够高效的将指定区间赋为指定值和实现区间翻转(注意是值取反,而不是类似于字符串的翻转)。

维护区间信息

线段树节点维护信息

维护区间信息,自然而然地想到线段树

那么需要维护哪些信息呢?

一般地,区间边界 $l,r$。

对于操作 $3$,显然我们需要维护区间内 $1$ 的数量。

考虑到需要进行“区间翻转”,因此我们也要维护 $0$ 的信息(见下文)。

维护 $cnt_0,cnt_1$ 表示区间内 $0,1$ 分别出现的次数,当然,你可以使用 $(r-l+1)-cnt_1$ 来获取 $cnt_0$,但是为了方便和美观,我们不这么做。

维护两个懒标记 $set$ 和 $reverse$,$set$ 表示区间是否被赋值,初始值为 $-1$;$reverse$ 是一个 bool 变量,用于标记是否翻转。

维护 $len_1$ 表示区间内最多的连续的 $1$ 的数量,考虑到需要翻转,因此也维护 $len_0$。

同时还需要维护 $pre_0,pre_1,suf_0,suf_1$,具体见下文。

参考代码

1
2
3
4
5
6
struct node{
    int l,r;
    int cnt[2],len[2],pre[2],suf[2];
    int set;
    bool reverse;
}t[4*N+1];

区间合并与上传更新

这是一定需要考虑的,在 $up(p)$ 中合并区间(上传更新),$cnt_0,cnt_1$ 都好合并,直接相加即可;然而 $len_0,len_1$ 却不好合并,因为我们需要考虑跨越中线的情况。

比如说,合并 $\texttt{1101}$ 和 $\texttt{1100}$。如果直接取最大值,就会发现合并后新的 $len_1=2$。然而合并后为 $\texttt{11011100}$,明显答案为 $3$;这就是所谓“跨越中线”。

我们可以效仿平衡树维护区间信息,维护 $pre_0,pre_1,suf_0,suf_1$ 分别表示最长前缀 $0$ 的长度、最长前缀 $1$ 的长度、最长后缀 $0$ 的长度、最长后缀 $1$ 的长度。

那么,对于 $pre_0$ 的转移明显就是 $t[p].pre_0\leftarrow t[p\times 2].pre_0$。

一个特殊的情况就是,$t[p\times 2]$ 所维护的区间全是 $0$。这样,还要加上 $t[p\times 2+1].pre_0$。

对于剩下三个,同理可得。

此时我们再来维护 $len_0,len_1$。

以 $len_0$ 为例,仅仅需要用原来取的最大值再与 $t[p\times 2].suf_0+t[p\times 2+1].pre_0$ 取最大值即可。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void up(int p){
    t[p].cnt[0]=t[p<<1].cnt[0]+t[p<<1|1].cnt[0];
    t[p].cnt[1]=t[p<<1].cnt[1]+t[p<<1|1].cnt[1];
    t[p].len[0]=max(t[p<<1].len[0],t[p<<1|1].len[0]);
    t[p].len[0]=max(t[p].len[0],t[p<<1].suf[0]+t[p<<1|1].pre[0]);
    t[p].len[1]=max(t[p<<1].len[1],t[p<<1|1].len[1]);
    t[p].len[1]=max(t[p].len[1],t[p<<1].suf[1]+t[p<<1|1].pre[1]);

    t[p].suf[0]=t[p<<1|1].suf[0];
    if(t[p<<1|1].suf[0]==size(p<<1|1))t[p].suf[0]+=t[p<<1].suf[0];
    t[p].suf[1]=t[p<<1|1].suf[1];
    if(t[p<<1|1].suf[1]==size(p<<1|1))t[p].suf[1]+=t[p<<1].suf[1];

    t[p].pre[0]=t[p<<1].pre[0];
    if(t[p<<1].pre[0]==size(p<<1))t[p].pre[0]+=t[p<<1|1].pre[0];
    t[p].pre[1]=t[p<<1].pre[1];
    if(t[p<<1].pre[1]==size(p<<1))t[p].pre[1]+=t[p<<1|1].pre[1];
}

懒标记与下传更新

上文已经提到,设计了一个 $set$ 和一个 $reverse$,现在展开详细解释。

我们定义 $down(p)$ 来下传更新(懒标记)。

那么需要传递两种操作:区间推平和区间翻转。

为了方便,我们先考虑推平再考虑翻转

注意:在区间推平后应当取消区间翻转标记。

因为当 $reverse=1$ 的时候,实际上是先翻转、再赋为指定值,因此实际上翻转已经完成然后被推平了,因此不应该翻转。

但这又是个问题了,为什么不能是 $set$ 先有值然后再有 $reverse\leftarrow 1$ 呢?

其实是有的,但是这也不需要翻转啊……大家都一样,翻转完和没翻没有区别,因此不做考虑

其实还有一些问题,见下文的区间推平


对于赋值,没什么好说的,见代码。

对于翻转,注意不要去交换左右子树即可。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
void down(int p){
    if(t[p].set!=-1){
        t[p<<1].cnt[t[p].set]=size(p<<1);
        t[p<<1].cnt[t[p].set^1]=0;
        t[p<<1].len[t[p].set]=size(p<<1);
        t[p<<1].len[t[p].set^1]=0;
        t[p<<1].pre[t[p].set]=size(p<<1);
        t[p<<1].pre[t[p].set^1]=0;
        t[p<<1].suf[t[p].set]=size(p<<1);
        t[p<<1].suf[t[p].set^1]=0;
        t[p<<1].set=t[p].set;
        t[p<<1].reverse=false;

        t[p<<1|1].cnt[t[p].set]=size(p<<1|1);
        t[p<<1|1].cnt[t[p].set^1]=0;
        t[p<<1|1].len[t[p].set]=size(p<<1|1);
        t[p<<1|1].len[t[p].set^1]=0;
        t[p<<1|1].pre[t[p].set]=size(p<<1|1);
        t[p<<1|1].pre[t[p].set^1]=0;
        t[p<<1|1].suf[t[p].set]=size(p<<1|1);
        t[p<<1|1].suf[t[p].set^1]=0;
        t[p<<1|1].set=t[p].set;
        t[p<<1|1].reverse=false;

        t[p].set=-1;
    }

    if(t[p].reverse){
        swap(t[p<<1].cnt[0],t[p<<1].cnt[1]);
        swap(t[p<<1].len[0],t[p<<1].len[1]);
        swap(t[p<<1].pre[0],t[p<<1].pre[1]);
        swap(t[p<<1].suf[0],t[p<<1].suf[1]);
        t[p<<1].reverse=!t[p<<1].reverse;

        swap(t[p<<1|1].cnt[0],t[p<<1|1].cnt[1]);
        swap(t[p<<1|1].len[0],t[p<<1|1].len[1]);
        swap(t[p<<1|1].pre[0],t[p<<1|1].pre[1]);
        swap(t[p<<1|1].suf[0],t[p<<1|1].suf[1]);
        t[p<<1|1].reverse=!t[p<<1|1].reverse;

        t[p].reverse=false;
    }
}

区间推平

注意清除翻转标记。

然后也没什么好说的,见代码。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void set(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r){
        t[p].cnt[k]=size(p);
        t[p].cnt[k^1]=0;
        t[p].len[k]=size(p);
        t[p].len[k^1]=0;
        t[p].pre[k]=size(p);
        t[p].pre[k^1]=0;
        t[p].suf[k]=size(p);
        t[p].suf[k^1]=0;

        t[p].set=k;
        t[p].reverse=false;//注意清除!
        return;
    }
    down(p);
    if(l<=t[p<<1].r)set(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)set(p<<1|1,l,r,k);
    up(p);
}

区间翻转

唯一需要注意的就是,$reverse$ 是 bool 变量,逻辑运算符 ! 取反即可。

然后就是直接交换四个信息。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void reverse(int p,int l,int r){
    if(l<=t[p].l&&t[p].r<=r){
        swap(t[p].cnt[0],t[p].cnt[1]);
        swap(t[p].len[0],t[p].len[1]);
        swap(t[p].pre[0],t[p].pre[1]);
        swap(t[p].suf[0],t[p].suf[1]);
        t[p].reverse=!t[p].reverse;
        return;
    }
    down(p);
    if(l<=t[p<<1].r)reverse(p<<1,l,r);
    if(t[p<<1|1].l<=r)reverse(p<<1|1,l,r);
    up(p);
}

查询区间内 $1$ 的个数

就是求 $cnt_1$ 意义上的区间和。

参考代码

1
2
3
4
5
6
7
8
int query(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r)return t[p].cnt[k];
    down(p);
    int ans=0;
    if(l<=t[p<<1].r)ans=query(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)ans+=query(p<<1|1,l,r,k);
    return ans;
}

查询区间内最多连续 $1$ 的个数

首先在不考虑“跨越中线”的情况下,即求 $len$ 意义下的区间最大值。

然后我们考虑跨越中线,类似的也就是把左子树的后缀与右子树的前缀长度加起来就行,注意记得取 $\min$,防止超出给定区间

参考代码

1
2
3
4
5
6
7
8
9
int query_len(int p,int l,int r,int k){
    if(l<=t[p].l&&t[p].r<=r)return t[p].len[k];
    down(p);
    int ans=0;
    if(l<=t[p<<1].r)ans=query_len(p<<1,l,r,k);
    if(t[p<<1|1].l<=r)ans=max(ans,query_len(p<<1|1,l,r,k));
    ans=max(ans,min(t[p<<1].suf[k],t[p<<1].r-l+1) + min(t[p<<1|1].pre[k],r-t[p<<1|1].l+1));
    return ans;
}

AC 代码

整整 $5\text{KB}$……

但是实话实说,虽然看起来长而且实际上也长,但是很多都是“复制粘贴 + 微调”可以解决的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
const int N=1e5;
int n,m,a[N+1];
struct seg{
	struct node{
		int l,r;
		int cnt[2],len[2],pre[2],suf[2];
		int set;
		bool reverse;
	}t[4*N+1];
	
	void up(int p){
		t[p].cnt[0]=t[p<<1].cnt[0]+t[p<<1|1].cnt[0];
		t[p].cnt[1]=t[p<<1].cnt[1]+t[p<<1|1].cnt[1];
		t[p].len[0]=max(t[p<<1].len[0],t[p<<1|1].len[0]);
		t[p].len[0]=max(t[p].len[0],t[p<<1].suf[0]+t[p<<1|1].pre[0]);
		t[p].len[1]=max(t[p<<1].len[1],t[p<<1|1].len[1]);
		t[p].len[1]=max(t[p].len[1],t[p<<1].suf[1]+t[p<<1|1].pre[1]);
		
		t[p].suf[0]=t[p<<1|1].suf[0];
		if(t[p<<1|1].suf[0]==size(p<<1|1))t[p].suf[0]+=t[p<<1].suf[0];
		t[p].suf[1]=t[p<<1|1].suf[1];
		if(t[p<<1|1].suf[1]==size(p<<1|1))t[p].suf[1]+=t[p<<1].suf[1];
		
		t[p].pre[0]=t[p<<1].pre[0];
		if(t[p<<1].pre[0]==size(p<<1))t[p].pre[0]+=t[p<<1|1].pre[0];
		t[p].pre[1]=t[p<<1].pre[1];
		if(t[p<<1].pre[1]==size(p<<1))t[p].pre[1]+=t[p<<1|1].pre[1];
	}
	void build(int p,int l,int r){
		t[p].l=l,t[p].r=r;
		t[p].set=-1;
		if(l==r){
			t[p].cnt[a[l]]=1;
			t[p].len[a[l]]=1;
			t[p].pre[a[l]]=1;
			t[p].suf[a[l]]=1;
			return;
		}
		int mid=l+r>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		up(p);
	}
	int size(int p){
		return t[p].r-t[p].l+1;
	}
	void down(int p){
		if(t[p].set!=-1){
			t[p<<1].cnt[t[p].set]=size(p<<1);
			t[p<<1].cnt[t[p].set^1]=0;
			t[p<<1].len[t[p].set]=size(p<<1);
			t[p<<1].len[t[p].set^1]=0;
			t[p<<1].pre[t[p].set]=size(p<<1);
			t[p<<1].pre[t[p].set^1]=0;
			t[p<<1].suf[t[p].set]=size(p<<1);
			t[p<<1].suf[t[p].set^1]=0;
			t[p<<1].set=t[p].set;
			t[p<<1].reverse=false;
			
			t[p<<1|1].cnt[t[p].set]=size(p<<1|1);
			t[p<<1|1].cnt[t[p].set^1]=0;
			t[p<<1|1].len[t[p].set]=size(p<<1|1);
			t[p<<1|1].len[t[p].set^1]=0;
			t[p<<1|1].pre[t[p].set]=size(p<<1|1);
			t[p<<1|1].pre[t[p].set^1]=0;
			t[p<<1|1].suf[t[p].set]=size(p<<1|1);
			t[p<<1|1].suf[t[p].set^1]=0;
			t[p<<1|1].set=t[p].set;
			t[p<<1|1].reverse=false;
			
			t[p].set=-1;
		}
		
		if(t[p].reverse){
			swap(t[p<<1].cnt[0],t[p<<1].cnt[1]);
			swap(t[p<<1].len[0],t[p<<1].len[1]);
			swap(t[p<<1].pre[0],t[p<<1].pre[1]);
			swap(t[p<<1].suf[0],t[p<<1].suf[1]);
			t[p<<1].reverse=!t[p<<1].reverse;
			
			swap(t[p<<1|1].cnt[0],t[p<<1|1].cnt[1]);
			swap(t[p<<1|1].len[0],t[p<<1|1].len[1]);
			swap(t[p<<1|1].pre[0],t[p<<1|1].pre[1]);
			swap(t[p<<1|1].suf[0],t[p<<1|1].suf[1]);
			t[p<<1|1].reverse=!t[p<<1|1].reverse;
			
			t[p].reverse=false;
		}
	}
	void set(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r){
			t[p].cnt[k]=size(p);
			t[p].cnt[k^1]=0;
			t[p].len[k]=size(p);
			t[p].len[k^1]=0;
			t[p].pre[k]=size(p);
			t[p].pre[k^1]=0;
			t[p].suf[k]=size(p);
			t[p].suf[k^1]=0;
			
			t[p].set=k;
			t[p].reverse=false;
			return;
		}
		down(p);
		if(l<=t[p<<1].r)set(p<<1,l,r,k);
		if(t[p<<1|1].l<=r)set(p<<1|1,l,r,k);
		up(p);
	}
	void reverse(int p,int l,int r){
		if(l<=t[p].l&&t[p].r<=r){
			swap(t[p].cnt[0],t[p].cnt[1]);
			swap(t[p].len[0],t[p].len[1]);
			swap(t[p].pre[0],t[p].pre[1]);
			swap(t[p].suf[0],t[p].suf[1]);
			t[p].reverse=!t[p].reverse;
			return;
		}
		down(p);
		if(l<=t[p<<1].r)reverse(p<<1,l,r);
		if(t[p<<1|1].l<=r)reverse(p<<1|1,l,r);
		up(p);
	}
	int query(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r)return t[p].cnt[k];
		down(p);
		int ans=0;
		if(l<=t[p<<1].r)ans=query(p<<1,l,r,k);
		if(t[p<<1|1].l<=r)ans+=query(p<<1|1,l,r,k);
		return ans;
	}
	int query_len(int p,int l,int r,int k){
		if(l<=t[p].l&&t[p].r<=r){
			return t[p].len[k];
		}
		down(p);
		int ans=0;
		if(l<=t[p<<1].r){
			ans=query_len(p<<1,l,r,k);
		}
		if(t[p<<1|1].l<=r){
			ans=max(ans,query_len(p<<1|1,l,r,k));
		}
		ans=max(ans,min(t[p<<1].suf[k],t[p<<1].r-l+1) + min(t[p<<1|1].pre[k],r-t[p<<1|1].l+1));
		return ans;
	}
}t;
void solve0(int l,int r){
	t.set(1,l,r,0);
}
void solve1(int l,int r){
	t.set(1,l,r,1);
}
void solve2(int l,int r){
	t.reverse(1,l,r);
}
void solve3(int l,int r){
	printf("%d\n",t.query(1,l,r,1));
}
void solve4(int l,int r){
	printf("%d\n",t.query_len(1,l,r,1));
}
typedef void (*fp)(int,int);
fp solve[5]={solve0,solve1,solve2,solve3,solve4};//函数指针
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	scanf("%d %d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",a+i);
	t.build(1,1,n);
	while(m--){
		int op,l,r;
		scanf("%d %d %d",&op,&l,&r);
		l++,r++;
		solve[op](l,r);
	}
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}