还算萌，但是代码有些难写……
你首先会想要
// Number of tree vertices and number of paths; both are read in executer().
int n;
int m;
// fa[i][x]: vertex reached from x by 2^i parent steps, filled in dfs().
int fa[20][300100];
// pa[x]: direct parent of x as passed to dfs() (same value as fa[0][x]).
int pa[300100];
// dep[x]: depth of x; dfs() sets it to dep[parent] + 1.
int dep[300100];
// Scratch counters indexed by path::sx / path::sy inside executer().
// Larger than the per-vertex arrays — presumably because jump() hands out
// ids above n for "virtual" children; TODO confirm the bound n + 2*m fits.
int cnt[900100];
// ldf[x] / rdf[x]: first and last DFS timestamp inside the subtree of x.
int ldf[300100];
int rdf[300100];
// dfc: running DFS timestamp counter incremented in dfs().
int dfc;
// qwq: number of virtual ids handed out so far by jump().
int qwq;
// Adjacency lists of the undirected tree, filled in executer().
vi<int> G[300100];
// One queried path, as filled in by executer().
struct path {
    int x;    // first endpoint
    int y;    // second endpoint
    int sx;   // jump() result for x: child of lca towards x, or a virtual id > n
    int sy;   // jump() result for y: child of lca towards y, or a virtual id > n
    int lca;  // lowest common ancestor of x and y
};
// Depth-first traversal from x whose parent is pre.
// Fills pa/fa (ancestor table), dep (depth) and the timestamp interval
// [ldf[x], rdf[x]] that covers exactly the vertices visited below x.
// NOTE(review): recursion depth equals the height of the tree; with arrays
// sized for ~3e5 vertices a path-shaped tree needs a generous stack.
void dfs(int x, int pre) {
    pa[x] = pre;
    fa[0][x] = pre;
    dep[x] = dep[pre] + 1;

    // Entry timestamp.
    ++dfc;
    ldf[x] = dfc;

    // The 2^level-th ancestor is the 2^(level-1)-th ancestor taken twice.
    for (int level = 1; level < 20; ++level) {
        const int halfway = fa[level - 1][x];
        fa[level][x] = fa[level - 1][halfway];
    }

    for (const int child : G[x]) {
        if (child == pre) {
            continue;  // do not walk back up the edge we arrived by
        }
        dfs(child, x);
    }

    // Largest timestamp handed out inside this subtree.
    rdf[x] = dfc;
}
// Lowest common ancestor of x and y via the fa[][] lifting table.
// Requires dep/fa/pa to have been filled by dfs() beforehand.
int lca(int x, int y) {
    // Make x the deeper (or equally deep) vertex.
    if (dep[x] < dep[y]) {
        const int tmp = x;
        x = y;
        y = tmp;
    }

    // Raise x while the lifted vertex is still at least as deep as y.
    for (int k = 19; k >= 0; --k) {
        const int lifted = fa[k][x];
        if (dep[lifted] >= dep[y]) {
            x = lifted;
        }
    }
    if (x == y) {
        return x;  // y was an ancestor of x
    }

    // Raise both in lock-step, stopping just below their meeting point.
    for (int k = 19; k >= 0; --k) {
        const int up_x = fa[k][x];
        const int up_y = fa[k][y];
        if (up_x != up_y) {
            x = up_x;
            y = up_y;
        }
    }
    return pa[x];
}
// Returns the vertex d parent-steps above x.
// For negative d no such vertex exists: a fresh, never-repeated id greater
// than n is returned instead (qwq counts how many were handed out).
int jump(int x, int d) {
    if (d < 0) {
        qwq += 1;
        return n + qwq;
    }
    // Apply the set bits of d from the highest table level downwards.
    for (int level = 19; level >= 0; --level) {
        if ((d & (1 << level)) != 0) {
            x = fa[level][x];
        }
    }
    return x;
}
// Fenwick (binary indexed) tree storage, 1-based; updated by add() and read by qry().
int bit[300100];
// Fenwick point update: adds y at 1-based position x (positions up to n).
void add(int x, int y) {
    while (x <= n) {
        bit[x] += y;
        x += x & -x;  // move to the next node whose range covers x
    }
}
// Fenwick prefix query: sum of all values added at positions 1..x.
int qry(int x) {
    int total = 0;
    while (x != 0) {
        total += bit[x];
        x -= x & -x;  // strip the lowest set bit
    }
    return total;
}
// Fenwick range query: sum of values at positions l..r (inclusive).
int qry(int l, int r) {
    const int through_r = qry(r);
    const int before_l = qry(l - 1);
    return through_r - before_l;
}
void executer() {
cin >> n;
for (int i=1,x,y; i<n; ++i) {
cin >> x >> y;
G[x].push_back(y);
G[y].push_back(x);
}
dfs(1, 0);
cin >> m;
vi<path> paths(m);
for (auto& it : paths) {
cin >> it.x >> it.y;
it.lca = lca(it.x, it.y);
it.sx = jump(it.x, dep[it.x]-dep[it.lca]-1);
it.sy = jump(it.y, dep[it.y]-dep[it.lca]-1);
if (it.sx > it.sy) {
swap(it.sx, it.sy), swap(it.x, it.y);
}
}
sort(paths.begin(), paths.end(), [&](const path& lhs, const path& rhs) {
return dep[lhs.lca] < dep[rhs.lca]
|| (dep[lhs.lca] == dep[rhs.lca] && lhs.lca < rhs.lca)
|| (lhs.lca == rhs.lca && lhs.sx > rhs.sx)
|| (lhs.sx == rhs.sx && lhs.sy > rhs.sy);
});
LL ans = 0;
for (int i=0; i<m;) {
int j = i;
while (j < m-1 && paths[j+1].lca == paths[j].lca) {
j++;
}
LL now = 0;
for (int l=i; l<=j;) {
int r = l;
while (r < j && paths[r+1].sx == paths[r].sx) {
r++;
}
for (int k=l; k<=r; ++k) {
ans += now-cnt[paths[k].sy];
}
for (int k=l; k<=r; ++k) {
cnt[paths[k].sx]++, cnt[paths[k].sy]++;
}
now += r-l+1, l = r+1;
}
for (int k=i; k<=j; ++k) {
cnt[paths[k].sx] = cnt[paths[k].sy] = 0;
}
for (int k=i; k<=j; ++k) {
ans += qry(ldf[paths[k].lca], rdf[paths[k].lca]);
if (paths[k].sx <= n) {
ans -= qry(ldf[paths[k].sx], rdf[paths[k].sx]);
}
if (paths[k].sy <= n) {
ans -= qry(ldf[paths[k].sy], rdf[paths[k].sy]);
}
}
for (int k=i; k<=j; ++k) {
add(ldf[paths[k].x], 1), add(ldf[paths[k].y], 1);
}
i = j+1;
}
cout << ans << "\n";
}