常用树上函数
title: 常用树上函数
categories:
- ICPC
tags:
- null
abbrlink: a5678561
date: 2023-09-12 00:00:00
常用树上函数
给定k个点·,判断他们能不能被同一条链覆盖。(也就是能不能构成简单路径)
Sol:找到链的两个端点p1,p2,链上的点x一定满足$lca(p1,x)=x,lca(p2,x)=lca(p1,p2)$或者$lca(p2,x)=x,lca(p1,x)=lca(p1,p2)$
auto checkchain = [&](vector<int> &node) {
int p1 = 0, p2 = 0;
deb(lcarmq.dep);
for (auto x : node) {
if (p1 == 0 || lcarmq.dep[p1] < lcarmq.dep[x])
p1 = x;
} // 找到最深端点
assert(p1 > 0);
for (auto x : node) {
if (lcarmq.lca(x, p1) != x) {
if (p2 == 0 || lcarmq.dep[p2] < lcarmq.dep[x])
p2 = x;
}
} // 找到另一侧得最深端点
if (p2 == 0)
return true; // 一条全是祖先关系的链,不存在转折点;
int ho = lcarmq.lca(p1, p2);
for (auto x : node) {
if (x == ho)
continue;
int tmp1 = lcarmq.lca(x, p1), tmp2 = lcarmq.lca(x, p2);
bool flag = 0;
if ((tmp1 == x && tmp2 == ho) || (tmp1 == ho && tmp2 == x))
flag = true;
if (flag == 0)
return false;
}
return true;
};
前置题目:给定一棵树。从根出发把所有点都走一遍的最短距离。
Sol:先考虑最后回到根节点,则答案就是2倍边权和$2\sum w$。最后停在一个最深节点即可,省去回溯那段。
struct edge {
int v, w;
};
void solve() {
int n, rt;
cin >> n >> rt;
vector<vector<edge>> e(n + 1);
vector<int> mxlen(n + 1);
int ans = 0;
int sum = 0;
for (int i = 1; i <= n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
e[u].push_back({v, w});
e[v].push_back({u, w});
sum += w;
}
function<void(int, int)> dfs = [&](int u, int fa) {
for (auto [v, w] : e[u]) {
if (v == fa)
continue;
dfs(v, u);
if (mxlen[u] < mxlen[v] + w)
mxlen[u] = mxlen[v] + w;
}
deb(u, mxlen[u]);
};
dfs(rt, rt);
deb(rt, mxlen[rt]);
ans = 2 * sum - mxlen[rt];
cout << ans << endl;
}
给定树上k个点,再给起点st和终点ed。要求从st出发最终走到ed。必须经过这k个点,求最短距离。
Sol:贪心考虑肯定是先访问除y以外的子树,最后访问完y的子树再回到y。考虑以x为根思考问题,这样贪心可以节约dep[y]的长度。剩下的点怎么计算?考虑直接让每个点暴力向上跳并且标记每个点,直到跳到有标记的点have_vis,这表示从have_vis到根已经其他点贡献过了。进入回溯每次+2.
void solve() {
int n, k, st, ed;
cin >> n >> k >> st >> ed;
vector<vector<int>> e(n + 1);
vector<bool> vis(n + 1);
vector<int> par(n + 1);
vector<int> node;
for (int i = 1; i <= k; i++) {
int x;
cin >> x;
node.push_back(x);
}
node.push_back(ed);
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
//------------------------------------
vector<int> dep(n + 1);
function<void(int, int)> dfs = [&](int u, int fa) {
for (auto v : e[u]) {
if (v == fa)
continue;
dep[v] = dep[u] + 1;
par[v] = u;
deb(u, v);
dfs(v, u);
}
deb(u, dep[u]);
};
dfs(st, st);
int ans = -dep[ed];
vis[st] = 1;
for (auto x : node) {
while (vis[x] == 0) {
vis[x] = true;
x = par[x];
ans += 2;
}
}
cout << ans << endl;
}
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 爱飞鱼的blog!