1. 程式人生 > 實用技巧 >【洛谷P6623】樹

【洛谷P6623】樹

題目

題目連結:https://www.luogu.com.cn/problem/P6623
給定一棵 \(n\) 個結點的有根樹 \(T\),結點從 \(1\) 開始編號,根結點為 \(1\) 號結點,每個結點有一個正整數權值 \(v_i\)
\(x\) 號結點的子樹內(包含 \(x\) 自身)的所有結點編號為 \(c_1,c_2,\dots,c_k\),定義 \(x\) 的價值為:

\[val(x)=(v_{c_1}+d(c_1,x)) \oplus (v_{c_2}+d(c_2,x)) \oplus \cdots \oplus (v_{c_k}+d(c_k, x)) \]

其中 \(d(x,y)\)

表示樹上 \(x\) 號結點與 \(y\) 號結點間唯一簡單路徑所包含的邊數,\(d(x,x) = 0\)\(\oplus\) 表示異或運算。
請你求出 \(\sum\limits_{i=1}^n val(i)\) 的結果。

思路

考慮從答案從 \(x\) 的子節點如何轉移到 \(x\) 上來:顯然是每一個節點的權值加一後再異或起來。
把每一個節點的權值 + 到目前根的距離轉成二進位制,由低位到高位扔進一棵 Trie 中,那麼把所有子樹內的點權值加一,其實就是沿著 Trie 邊權為 1 的點走下去,並且把沿路遇到的 0 給變成 1,再把沿路的 1 變為 0。那麼其實就是將遍歷到的節點的左右子樹交換了。
插入自己本身的權值簡單,那麼我們就成功搞定加一操作和插入操作。我們只需要將 \(x\)

的各個子節點的 Trie 合併到 \(x\) 上。類似於線段樹合併。
時間複雜度 \(O(n\log n)\)

程式碼

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N=525020,LG=25;
int n,tot,a[N],head[N],rt[N];
ll ans;

struct edge
{
	int next,to;
}e[N];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

struct Trie
{
	int tot,lc[N*LG],rc[N*LG],val[N*LG],size[N*LG];
	
	void pushup(int x,int dep)
	{
		val[x]=val[lc[x]]^val[rc[x]];
		if (size[rc[x]]&1) val[x]^=(1<<dep);
	}
	
	int merge(int x,int y,int dep)
	{
		if (!x || !y) return x+y;
		size[x]+=size[y];
		lc[x]=merge(lc[x],lc[y],dep+1);
		rc[x]=merge(rc[x],rc[y],dep+1);
		pushup(x,dep);
		return x;
	}
	
	void update(int x,int dep)
	{
		if (!x) return;
		swap(lc[x],rc[x]);
		update(lc[x],dep+1);
		pushup(x,dep);
	}
	
	int ins(int x,int val,int dep)
	{
		if (dep>22) return 0;
		if (!x) x=++tot;
		size[x]++;
		if (val&(1<<dep)) rc[x]=ins(rc[x],val,dep+1);
			else lc[x]=ins(lc[x],val,dep+1);
		pushup(x,dep);
		return x;
	}
}trie;

void dfs(int x)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		dfs(e[i].to);
		rt[x]=trie.merge(rt[x],rt[e[i].to],0);
	}
	trie.update(rt[x],0);
	rt[x]=trie.ins(rt[x],a[x],0);
	ans+=trie.val[rt[x]];
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=2,x;i<=n;i++)
	{
		scanf("%d",&x);
		add(x,i);
	}
	dfs(1);
	printf("%lld",ans);
	return 0;
}