西电校赛2022D题——数位相邻

题意:

给你一个 1e6 位数 n 和一个数 k ,要求你找出最小的数,满足:

  1. 大于 n
  2. 没有相邻两个数位的差为 k

0k90\leq k\leq 9

这题第一眼看上去有一个很美妙的做法:

首先让 n+1 ,这样就把限制转化大于等于 n ,更好解决。

然后从高位往低位看,如果有数位差为 k ,就把较低位 +1 ,然后更低位全变 0 。

这样做会有许多特例:

  1. 较低位是 9 怎么办
  2. 较高位也是 9 怎么办
  3. k=0 怎么办
  4. k=1 怎么办

沿着这些特例往下想,其实会遇到更多问题。

当然这些问题仔细想想也是能想清的,标解就是特例特判。但是在考试现场去扣这些特例很可能会陷入智商和心态双重炸裂的不利境地。

因此我选择使用容易证明正确性的数位 dp 来做这道题。

数位 dp 题虽然一般认为也很难,但这道题只有下界且限制条件比较简单,因此还是很好做的。

首先还是让 n+1 转化为大于等于的情况。

接着设置状态:令 f[i][j][k]f[i][j][k] 表示在只考虑第 1 到 i 位的情况下,第 i 位选 j 且自由状态为 k 的方案是否可行。自由状态就是指当更高位已经大于原数,本位可以随意选取时,k=1 ;当更高位还未大于原数,本位只能大于等于原数时,k=0 。

状态转移方程很好解决:

(1) k=1 时

由于本位已经自由,更低位自然也自由,因此

f[i][j][1]=j=09f[i1][j][1]f[i][j][1]=\sum_{j'=0}^{9}f[i-1][j'][1]

(2) k=0 时

此时可以以 j 为依据分成三个部分。

  1. j<a[i]j<a[i] ,不可能
  2. j=a[i]j=a[i] ,只能由更低位的非自由态( k=0 转移来)
  3. j>a[i]j>a[i] ,可以由更低位的自由态转移来

因此有

{f[i][j][0]=0j<a[i]f[i][j][0]=j=09f[i1][j][0]j=a[i]f[i][j][0]=j=09f[i1][j][1]j>a[i].\begin{cases}{} f[i][j][0]=0 & j<a[i] \\ \\ f[i][j][0]=\sum_{j'=0}^{9}f[i-1][j'][0] & j=a[i] \\ \\ f[i][j][0]=\sum_{j'=0}^{9}f[i-1][j'][1] & j>a[i] \end{cases} .

初始状态:f[0]f[0] 的所有数位的自由态和非自由态都可行。

这样裸写出来的 dp 需要进行 100×n100\times n 次计算,因此很容易被卡常。

实际上在想特例构造方案的时候就可以发现,答案中相当大的数位其实都属于以下两种情况:

  1. 和原数位一样
  2. 要么是 0 要么是 1

而且只要 a[i]a[i] 是 0 可以,那么 a[i+1]a[i+1] 除了 k 不可以,剩下都可以,完全没必要再枚举 9 个数加一遍。

所以还可以加一个小优化:

把 k=1 和 k=0 的情况分开计算,一旦可以使 f[i][j][k]f[i][j][k] 为 1 ,就立刻 break ,不再枚举后面的。

这样就可以卡过了

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include<bits/stdc++.h>
using namespace std;
int n,m,o;
char s[1100000];
int a[1100000];
int f[1100000][10][2];
void rvs(){
for(int i=1;i*2<=n;i++){
swap(a[i],a[n-i+1]);
}
}
void fwd(){
for(int i=1;i<=n;i++){
if(a[i]>=10){
a[i+1]+=a[i]/10;
a[i]%=10;
if(i==n){
n++;
}
}
}
for(;a[n]>=10;n++){
a[n+1]+=a[n]/10;
a[n]%=10;
}
}
void dp(){
for(int i=0;i<10;i++){
f[0][i][0]=1;
f[0][i][1]=1;
}
for(int i=1;i<=n+1;i++){
for(int j=0;j<10;j++){
for(int k=0;k<10;k++){
if(abs(j-k)!=m && (f[i-1][k][0] | f[i-1][k][1])){
f[i][j][1]=1;
break;
}
}
}
for(int j=a[i]+1;j<10;j++){
for(int k=0;k<10;k++){
if(abs(j-k)!=m && f[i-1][k][1]){
f[i][j][0]=1;
break;
}
}
}
for(int k=0;k<10;k++){
if(abs(a[i]-k)!=m && f[i-1][k][0]){
f[i][a[i]][0]=1;
break;
}
}
}
}
void otp(int x){
int y=-100,z=0;
for(int i=x;i>=1;i--){
for(int j=0;j<10;j++){
if(abs(j-y)==m){
continue;
}
if(z==0){
if(j>=a[i] && f[i][j][0]==1){
y=j;
if(j>a[i]){
z=1;
}
break;
}
}
else{
if(f[i][j][1]==1){
y=j;
break;
}
}
}
printf("%d",y);
}
printf("\n");
}
int main(){
scanf("%d",&o);
while(o --> 0){
scanf("%s%d",s+1,&m);
n=strlen(s+1);
for(int i=1;i<=n;i++){
a[i]=s[i]-'0';
}
rvs();
a[1]++;
fwd();
dp();
int t=0;
for(int i=a[n];i<10;i++){
if(f[n][i][0]==1){
t=n;
break;
}
}
if(t==0){
t=n+1;
}
otp(t);
for(int i=1;i<=n+1;i++){
a[i]=0;
for(int j=0;j<10;j++){
f[i][j][0]=0;
f[i][j][1]=0;
}
}
}
return 0;
}