常用树上函数

给定k个点·,判断他们能不能被同一条链覆盖。(也就是能不能构成简单路径)

Sol:找到链的两个端点p1,p2,链上的点x一定满足$lca(p1,x)=x,lca(p2,x)=lca(p1,p2)$或者$lca(p2,x)=x,lca(p1,x)=lca(p1,p2)$

auto checkchain = [&](vector<int> &node) {
        int p1 = 0, p2 = 0;
        deb(lcarmq.dep);
        for (auto x : node) {
            if (p1 == 0 || lcarmq.dep[p1] < lcarmq.dep[x])
                p1 = x;
        }  // 找到最深端点
        assert(p1 > 0);
        for (auto x : node) {
            if (lcarmq.lca(x, p1) != x) {
                if (p2 == 0 || lcarmq.dep[p2] < lcarmq.dep[x])
                    p2 = x;
            }
        }  // 找到另一侧得最深端点
        if (p2 == 0)
            return true;  // 一条全是祖先关系的链,不存在转折点;
        int ho = lcarmq.lca(p1, p2);
        for (auto x : node) {
            if (x == ho)
                continue;
            int tmp1 = lcarmq.lca(x, p1), tmp2 = lcarmq.lca(x, p2);
            bool flag = 0;
            if ((tmp1 == x && tmp2 == ho) || (tmp1 == ho && tmp2 == x))
                flag = true;
            if (flag == 0)
                return false;
        }
        return true;
    };

前置题目:给定一棵树。从根出发把所有点都走一遍的最短距离。

Sol:先考虑最后回到根节点,则答案就是2倍边权和$2\sum w$。最后停在一个最深节点即可,省去回溯那段。

struct edge {
    int v, w;
};
void solve() {
    int n, rt;
    cin >> n >> rt;
    vector<vector<edge>> e(n + 1);
    vector<int> mxlen(n + 1);
    int ans = 0;
    int sum = 0;
    for (int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        e[u].push_back({v, w});
        e[v].push_back({u, w});
        sum += w;
    }
    function<void(int, int)> dfs = [&](int u, int fa) {
        for (auto [v, w] : e[u]) {
            if (v == fa)
                continue;
            dfs(v, u);
            if (mxlen[u] < mxlen[v] + w)
                mxlen[u] = mxlen[v] + w;
        }
        deb(u, mxlen[u]);
    };
    dfs(rt, rt);
    deb(rt, mxlen[rt]);
    ans = 2 * sum - mxlen[rt];
    cout << ans << endl;
}

给定树上k个点,再给起点st和终点ed。要求从st出发最终走到ed。必须经过这k个点,求最短距离。

Sol:贪心考虑肯定是先访问除y以外的子树,最后访问完y的子树再回到y。考虑以x为根思考问题,这样贪心可以节约dep[y]的长度。剩下的点怎么计算?考虑直接让每个点暴力向上跳并且标记每个点,直到跳到有标记的点have_vis,这表示从have_vis到根已经其他点贡献过了。进入回溯每次+2.

void solve() {
    int n, k, st, ed;
    cin >> n >> k >> st >> ed;
    vector<vector<int>> e(n + 1);
    vector<bool> vis(n + 1);
    vector<int> par(n + 1);
    vector<int> node;
    for (int i = 1; i <= k; i++) {
        int x;
        cin >> x;
        node.push_back(x);
    }
    node.push_back(ed);
    for (int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    //------------------------------------
    vector<int> dep(n + 1);
    function<void(int, int)> dfs = [&](int u, int fa) {
        for (auto v : e[u]) {
            if (v == fa)
                continue;
            dep[v] = dep[u] + 1;
            par[v] = u;
            deb(u, v);
            dfs(v, u);
        }
        deb(u, dep[u]);
    };
    dfs(st, st);
    int ans = -dep[ed];
    vis[st] = 1;
    for (auto x : node) {
        while (vis[x] == 0) {
            vis[x] = true;
            x = par[x];
            ans += 2;
        }
    }
    cout << ans << endl;
}