1. 程式人生 > 實用技巧 >題解【[SCOI2016]幸運數字】

題解【[SCOI2016]幸運數字】

\[\texttt{Description} \]

給出一棵包含 \(n\) 個點的樹,點帶權。

\(Q\) 次詢問,每次詢問給出兩個點 \(x\)\(y\),求 \(x\)\(y\) 的簡單路徑上,任意選擇若干個點,使得其點權異或和最大。

\(1 \leq n \leq 2 \times 10^4\)\(1 \leq Q \leq 2 \times 10^5\)\(0 \leq g_i \leq 2^{60}\)

\[\texttt{Solution} \]

前置知識:樹上倍增,線性基。

首先我們知道線性基是可以暴力合併的,即把線性基 \(b\) 的元素一個個地插入線性基 \(a\)

中。定義 \(\text{merge}\) 運算合併兩個線性基,顯然 \(\text{merge}\) 運算的複雜度是 \(\mathcal{O(\log^2 n)}\) 的。

考慮樹上倍增。設 \(f_{i, j}\) 表示節點 \(i\) 向上跳 \(2^j\) 步所到達的節點編號,設 \(g_{i, j}\) 表示節點 \(i\) 向上跳 \(2^j\) 步所經過的所有節點(不包括節點 \(i\))的點權所組成的線性基。則有:

\[f_{i, j} = f_{f_{i, j - 1},j - 1} \]

\[g_{i, j} = g_{i, j - 1} \ \text{merge} \ g_{f_{i, j - 1}, j - 1} \]

通過簡單的 BFS 即可預處理出 \(f\)\(g\)

詢問也按照倍增求 \(\text{lca}\) 的框架,令 \(x, y\) 向上跳,\(x, y\) 每向上移動一段路徑,就合併該路徑對應的線性基。

至此我們有一個 \(\mathcal{O(n \log^3 n) - O(\log^3 n)}\) 的做法。

這個做法還不夠優秀,考慮挖掘線性基合併的一些性質。


有一個樹上倍增的 trick:

若一個運算滿足 \(a \oplus a = a\),則我們稱這個運算滿足 " 可重複貢獻 " 性。

例如 \(\max\)\(\min\)\(\gcd\) 等運算均滿足 " 可重複貢獻 "

性。

考慮 \(x\)\(y\) 之間的路徑,記 \(z = \text{lca}(x, y)\)

我們將 \(x\)\(y\) 的路徑分成了 \({\color{red}紅路徑}\)\({\color{orange}橙路徑}\)\({\color{yellow}黃路徑}\)\({\color{green}綠路徑}\) 四部分,每一部分的長度都是不超過 " 所在大路徑的長度 " 的 \(2\)最大整數次冪(\({\color{red}紅路徑}\)\({\color{orange}橙路徑}\) 歸在 \(x\)\(z\) 的大路徑裡,\({\color{yellow}黃路徑}\)\({\color{green}綠路徑}\) 歸在 \(y\)\(z\) 的大路徑裡),每一部分的貢獻都可以通過倍增陣列求出,故答案為 \({\color{red}紅路徑貢獻} \oplus {\color{orange}橙路徑貢獻} \oplus {\color{yellow}黃路徑貢獻} \oplus {\color{green}綠路徑貢獻}\)

\(\oplus\) 運算的複雜度為 \(\mathcal{O(w)}\),若配合 " 長鏈剖分求樹上 \(k\) 級祖先 " 以及 " tarjan 求 \(\text{lca}\) ",即可 \(\mathcal{O(w)}\) 回答詢問。


回到此題,注意到線性基合併也是滿足 " 可重複貢獻 " 性的。具體的說:若線性基 \(a\) 與線性基 \(b\) 所代表的路徑之間有交集,則線性基 \(a\) 與線性基 \(b\) 經過 \(\text{merge}\) 運算後所得到的線性基代表的路徑為線性基 \(a\) 與線性基 \(b\) 所代表的路徑的並集。因為交集部分的元素在重複插入線性基的時候顯然不會多做貢獻。

於是套用上述 trick 即可將複雜度優化至 \(\mathcal{O(n \log^3 n) - O(\log^2 n)}\)

注意到 \(\text{merge}\) 操作的複雜度為 \(\mathcal{O(\log^2 n)}\),所以我們還是可以用樹上倍增來求 " 樹上 \(k\) 級祖先 " 以及 " \(\text{lca}\) "。

\[\texttt{Code} \]

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

const int N = 20100, M = 40100;

int logx[N];

int n, m;

long long val[N];

int tot, head[N], ver[M], Next[M];

void add(int u, int v) {
	ver[++ tot] = v;    Next[tot] = head[u];    head[u] = tot;
}

struct bas {
	long long p[64];
	bas() { 
		for (int i = 0; i <= 63; i ++)
			p[i] = 0;
	} 
};

long long calc(bas a) {
	long long ans = 0;
	for (int i = 63; i >= 0; i --)
		if ((ans ^ a.p[i]) > ans) ans ^= a.p[i];
	return ans;
}

bas operator + (bas a, long long x) { 
	for (int i = 63; i >= 0; i --) {
		if (!(x >> i)) continue;
		if (!a.p[i]) { a.p[i] = x; break; }
		else x ^= a.p[i];
	}
	return a;
}

bas operator + (bas a, bas b) {
	for (int i = 63; i >= 0; i --)
		if (b.p[i]) a = a + b.p[i];
	return a;
}

int d[N];
int f[N][20];
bas g[N][20];

void bfs() {
	queue<int> q;
	q.push(1), d[1] = 1;
	while (q.size()) {
		int u = q.front(); q.pop();
		for (int i = head[u]; i; i = Next[i]) {
			int v = ver[i];
			if (d[v]) continue;
			d[v] = d[u] + 1;
			f[v][0] = u;
			g[v][0] = g[v][0] + val[u];
			for (int j = 1; j <= 19; j ++) {
				f[v][j] = f[f[v][j - 1]][j - 1];
				if (f[v][j]) g[v][j] = g[v][j - 1] + g[f[v][j - 1]][j - 1];
			}
			q.push(v);
		}
	}
} 

int lca(int x, int y) {
	if (d[x] > d[y]) swap(x, y);
	for (int i = 19; i >= 0; i --)
		if (d[x] <= d[f[y][i]]) y = f[y][i];
	if (x == y) return x;
	for (int i = 19; i >= 0; i --)
		if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
} 

int jump(int x, int lv) {
	for (int i = 19; i >= 0; i --)
		if (lv >> i & 1) x = f[x][i];
	return x;
}

long long ask(int x, int y) {
	bas ans;
	ans = ans + val[x];
	ans = ans + val[y];
	int L = lca(x, y);
	if (d[x] - d[L] >= 1) {
		int lv = logx[d[x] - d[L]];
		ans = ans + g[x][lv];
		ans = ans + g[jump(x, d[x] - d[L] - (1 << lv))][lv];
	} if (d[y] - d[L] >= 1) {
		int lv = logx[d[y] - d[L]];
		ans = ans + g[y][lv];
		ans = ans + g[jump(y, d[y] - d[L] - (1 << lv))][lv];
	}
	return calc(ans);
}

int main() {
	logx[0] = -1;
	for (int i = 1; i <= 20000; i ++)
		logx[i] = logx[i / 2] + 1;

	scanf("%d%d", &n, &m);

	for (int i = 1; i <= n; i ++)
		scanf("%lld", &val[i]);

	for (int i = 1, u, v; i < n; i ++) {
		scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	}

	bfs();

	while (m --) {
		int x, y; scanf("%d%d", &x, &y);
		printf("%lld\n", ask(x, y));
	} 

	return 0;
}

\[\texttt{Thanks} \ \texttt{for} \ \texttt{reading} \]