Solution -「HDU」Ridiculous Netizens

cirnovsky /

§ Desc.

给定一棵 NN 个节点无根树,找出满足以下条件的集合 SS 的数量:

  • S{1,,n}S \subseteq \{1,\dots,n\}
  • SS 的导出子图联通;
  • vSavM\displaystyle\prod_{v \in S} a_v \leqslant M

§ Sol.

点分治,统计包括当前分治中心的集合数量,如果从子树的角度入手会发现并不好做——合并这一步就卡死了。考虑以 DFN 为状态,设 f(i,j)f(i,j) 表示在子树中 DFN 序排后 ii 个节点中选出了乘积为 jj 的集合。这个状态实际上是很浪费空间的,那么使用根号分治,另令 g(i,j)g(i, j) 表示乘积 Mj\frac{M}{j} 时的方案数,这样就开得下了。

const int MN = 2e3, B = 1e3;
int n, m, a[MN + 100], f[MN + 100][B + 100], g[MN + 100][B + 100];
int sz[MN + 100], res, rev[MN + 100], Time, mxsb[MN + 100], rt;
bool vis[MN + 100];
vvi grp;
void getsz(int u, int Fu) {
    sz[u] = 1; for (int v : grp[u]) if (v != Fu && !vis[v]) getsz(v, u), sz[u] += sz[v];
}
void getrt(int u, int Fu, int all) {
    mxsb[u] = all-sz[u];
    for (int v : grp[u]) if (v != Fu && !vis[v]) {
        getrt(v, u, all); chkmax(mxsb[u], sz[v]);
    }
    if (rt == 0 || mxsb[u] < mxsb[rt]) rt = u;
}
void dfs(int u, int Fu) {
    rev[Time++] = u;
    for (int v : grp[u]) if (v != Fu && !vis[v]) dfs(v, u);
}
void solve(int u) {
    rt = 0; getsz(u, n); getrt(u, n, sz[u]); vis[rt] = 1; Time = 0; dfs(rt, n); getsz(rt, n);
    rep(Time+1) {
        memset(f[i], 0, sizeof f[i]);
        memset(g[i], 0, sizeof g[i]);
    }
    f[Time][1] = 1;
    drep(i,Time-1,0) {
        int w = a[rev[i]];
        // Putting @i into the component.
        rep(j,1,B+1) {
            if (j <= B / w) add_eq(f[i][j * w], f[i+1][j]);
            else if ((m / j) / w > 0) add_eq(g[i][(m / j) / w], f[i+1][j]);
        }
        rep(j,w,B+1) add_eq(g[i][j/w], g[i+1][j]);
        // Putting @i out of the component, skipping its subtree.
        rep(j,1,B+1) {
            add_eq(f[i][j], f[i+sz[rev[i]]][j]);
            add_eq(g[i][j], g[i+sz[rev[i]]][j]);
        }
    }
    rep(i,1,min(B, m)+1) add_eq(res, add(f[0][i], g[0][i]));
    rep(i,min(B, m),B+1) add_eq(res, g[0][i]);
    sub_eq(res, 1);
    for (int v : grp[rt]) if (!vis[v]) solve(v);
}
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    cin >> n >> m;
    grp = vvi(n);
    rep(n) cin >> a[i];
    rep(n-1) {int u,v; cin >> u >> v; u--; v--; grp[u].pb(v); grp[v].pb(u);}
    solve(0); cout << res << "\n";
}