Solution -「CF 1608F」MEX counting

cirnovsky /

link。

首先考虑暴力,枚举规划前缀 [1,i][1, i] 和前缀 mex xx,则我们需要 xx 个数来填了 [0,x)[0, x),还剩下 ixi-x 个数随便填 [0,x)(x,n][0, x) \cup (x, n],如果直接组合数算可能会出问题,考虑 dp。

定义 f[i][x][j]f[i][x][j] 表示规划前缀 [1,i][1, i],当前的 mex 为 xx,还有 jj 个数当前不对 mex 的取值造成影响(也就是说他们都大于 xx,这 jj 个数是指在 aa 数组中的,所以我们不必关心这 jj 个数具体是什么)。转移就分两种情况:

  • a[i+1]xa[i+1] \neq x(i+1,x,j)(i+1,x,j)+(i,x,j)(x+j)(i+1, x, j) \gets (i+1, x, j)+(i, x, j)*(x+j)(i+1,x,j+1)(i,x,j)(i+1, x, j+1) \gets (i, x, j),第一个就是考虑当加入的 a[i+1]a[i+1] 属于那 jj 个数或者属于 [0,x)[0, x),第二个很简单;
  • a[i+1]=xa[i+1] = x:设当前 mex 变成了 yy,则有 (i+1,y,jy+x+1)(i+1,y,jy+x+1)+(i,x,j)×(jyx1)×(yx1)!(i+1, y, j-y+x+1) \gets (i+1, y, j-y+x+1)+(i, x, j) \times \binom{j}{y-x-1} \times (y-x-1)!,意义明显,注意后面那个是排列数而不是组合数。

然后这个是 O(n2k2)O(n^2k^2) 的,把刷表改成填表后前缀和优化即可。

using modint = modint998244353;
int n, K, a[2100];
modint dp[2][2100][2100], sum[2][2100][2100], fct[2100], ifct[2100];
inline int L(int i) { return max(0, a[i]-K); }
inline int R(int i) { return min(i, a[i]+K); }
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    fct[0] = 1;
    for (int i=1; i<2100; ++i) {
        fct[i] = fct[i-1]*i;
    }
    ifct[2099] = fct[2099].inv();
    for (int i=2098; i>=0; --i) {
        ifct[i] = ifct[i+1]*(i+1);
    }
    cin >> n >> K;
    for (int i=1; i<=n; ++i) {
        cin >> a[i];
    }
    dp[0][0][0] = sum[0][0][0] = 1;
    for (int i=1,cur=1; i<=n; ++i,cur^=1) {
        for (int j=0; j<=i; ++j) {
            for (int x=L(i); x<=R(i) && x<=j; ++x) {
                dp[cur][x][j] = dp[cur^1][x][j]*j;
                if (j) {
                    dp[cur][x][j] += dp[cur^1][x][j-1];
                }
                if (j && x) {
                    dp[cur][x][j] += sum[cur^1][min(x-1, R(i-1))][j-1]*ifct[j-x];
                }
                sum[cur][x][j] = dp[cur][x][j]*fct[j-x];
                if (x) {
                    sum[cur][x][j] += sum[cur][x-1][j];
                }
            }
        }
        for (int j=0; j<=i-1; ++j) {
            for (int x=L(i-1); x<=R(i-1) && x<=j; ++x) {
                dp[cur^1][x][j] = sum[cur^1][x][j] = 0;
            }
        }
    }
    modint ans = 0;
    for (int i=0; i<=n; ++i) {
        for (int j=L(n); j<=R(n) && j<=i; ++j) {
            ans += dp[n&1][j][i]*fct[n-j]*ifct[n-i];
        }
    }
    cout << ans.val() << "\n";
}