1. 程式人生 > 實用技巧 >HDU 4916 Count on the path 樹形dp

HDU 4916 Count on the path 樹形dp

題意:給出一棵樹,f(a,b)表示不在a,b路徑上的最小點的編號。q組詢問,每次給出u,v,詢問f(u,v)

題解:顯然如果兩個點的lca不是1,那麼答案就是1。可以把1號節點當作根做一次dfs,然後去掉這個點,形成了一片森林。在森林中,預處理出\(mn\_id[i]\)表示i子樹內最小點的編號:\(mn\_id[u]=min(u,mn\_id[v])\)

第二次dfs處理出\(dp[i]\)表示不在i到新根節點的路上的最小點編號。

\(dp[u]=min(mn\_id[v],u的兄弟中最小的mn\_id)\)

怎麼得到u的兄弟中最小的mn_id呢,其實就是在遞迴之前記錄一下u的父親fa的所有兒子的mn_id中排名前兩小的值,當mn_id[u]==mn1時就傳入mn2,否則傳入mn1

查詢的時候,先用並查集的方法檢查a和b是否在一個連通塊,在的話顯然lca不等於1,答案為1。

如果不在同一個連通塊,說明lca為1,那麼就要在它們兩棵子樹中取一個最小值,即取\(min(dp[a],dp[b])\)

但是除了這兩個連通塊外,還要考慮圖中其他連通塊的最小值。所以可以在事先處理出來mn_id[rt]最小的前三個值。最後輸出答案的時,先取dp[a],dp[b]中的最小值,再和u和v所屬連通塊以外的,屬於前3小的連通塊的mn_id[rt]取min。

#include<cstdio>
#include<algorithm>
using namespace std;

const int N=1e6+10;
int n,q,pos,head[N];
struct edge
{
	int to,next;
}e[N<<1];
void add(int x,int y)
{
	e[++pos].next=head[x];
	e[pos].to=y;
	head[x]=pos;
}
int read()
{
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
	return x*f; 
}
void update(int &x,int y){
	if(x>y) x=y;
}
struct node
{
	int d,bel;
	bool operator < (const node &x)const{
		return d<x.d;
	}
	bool operator != (const int &x)const{
		return bel!=x; 
	}
	void clear(){d=1e9;}
}b[4];
int fa[N],bel[N];
int mn_id[N],dp[N];
void dfs(int u,int f)
{
	mn_id[u]=u;
	fa[u]=f;
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to;if(v==f)continue;
		if(u!=1) bel[v]=bel[u];
		dfs(v,u);
		update(mn_id[u],mn_id[v]);
	}
}
void dfs1(int u,int res)
{
	dp[u]=res;
	int mn1=1e9,mn2=1e9;
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to;if(v==fa[u])continue;
		update(dp[u],mn_id[v]);
		if(mn_id[v]<mn1) mn2=mn1,mn1=mn_id[v];
		else if(mn_id[v]<mn2) mn2=mn_id[v];
	}
	for(int i=head[u];i;i=e[i].next)
	{
		int nxt;
		int v=e[i].to;if(v==fa[u])continue;
		if(mn1==mn_id[v]) nxt=mn2;
		else nxt=mn1;
		dfs1(v,min(res,nxt));
	}
}
void init()
{
	dp[1]=1e9;pos=0;
	for(int i=1;i<=n;i++)head[i]=0,bel[i]=i;
}

int main()
{
	while(~scanf("%d%d",&n,&q))
	{
		init();
		for(int i=1;i<n;i++)
		{
			int a,b;scanf("%d%d",&a,&b);
			add(a,b),add(b,a);
		} 
		dfs(1,0);
		for(int i=1;i<=3;i++) b[i].clear();
		for(int i=head[1];i;i=e[i].next)
		{
			int v=e[i].to;
			dfs1(v,1e9);
			node P;P.d=mn_id[v],P.bel=v;
			if(P<b[1]) b[3]=b[2],b[2]=b[1],b[1]=P;
			else if(P<b[2]) b[3]=b[2],b[2]=P;
			else if(P<b[3]) b[3]=P;
		}
		int lastans=0;
		while(q--)
		{
			int u,v;scanf("%d%d",&u,&v);
			u=u^lastans,v=v^lastans;
			if(u==v&&u==1)
			{
				puts("2");
				continue;
			}
			if(bel[u]==bel[v])
			{
				puts("1");
				lastans=1;
				continue;
			}
			lastans=min(dp[u],dp[v]);
			u=bel[u],v=bel[v];
			if(b[1]!=u&&b[1]!=v) update(lastans,b[1].d);
			else if(b[2]!=u&&b[2]!=v) update(lastans,b[2].d);
			else if(b[3]!=u&&b[3]!=v) update(lastans,b[3].d);
			printf("%d\n",lastans);
		}
	}
	return 0;
}