2018.11.05【校內模擬】規避(最短路計數)(容斥)(正難則反)
阿新 • • 發佈:2018-12-19
傳送門
解析:
首先直接統計並不好做,考慮反著做,先求出總共的方案數,然後減去相遇的方案數。
總方案數就是到的最短路數量的平方(兩人分別作選擇)。
首先這是個計數類問題,先做一個最短路計數。
令表示到的最短路長度,表示到的最短路數量,和同理,記最短路長度為
怎麼統計相遇? 首先我們發現兩人的相交位置一定是最短路的中點,這個中點可能是點也可能是邊,所以考慮以中點為標誌統計答案。
當一個點不在任何一條最短路上時,直接。 否則若這個點滿足,則兩人有可能在這個點上相遇,方案數為,因為兩人要走完全程,又必須經過點,根據乘法原理可以輕易得出答案。
如果這個點滿足,存在一條邊,且,那麼兩人就可能在這條邊上相遇。那麼經過這條邊的最短路數就是,根據乘法原理,平方一下就好了。
程式碼:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const
inline int getint(){
re int num;
re char c;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
return num;
}
cs ll mod=1000000007;
cs int N=100005,M=200005;
int last[N],nxt[M<<1],to[M<<1],ecnt;
int w[M<<1];
inline void addedge(int u,int v,int val){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v,w[ecnt]=val;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u,w[ecnt]=val;
}
ll distS[N],distT[N];
ll cntS[N],cntT[N],tot,ans;
bool flag;
bool vis[N];
set<pair<ll,int> > q;
inline void Dijkstra(ll *cs dist,ll *cs cnt,int S,int T){
memset(dist,0x3f,sizeof distS);
dist[S]=0;cnt[S]=1;
q.clear();
q.insert(make_pair(0,S));
memset(vis,0,sizeof vis);
while(!q.empty()){
int u=q.begin()->second;
q.erase(q.begin());
if(vis[u])continue;
vis[u]=true;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(dist[v]>dist[u]+w[e]){
q.erase(make_pair(dist[v],v));
dist[v]=dist[u]+w[e];
cnt[v]=0;
q.insert(make_pair(dist[v],v));
}
if(dist[v]==dist[u]+w[e])cnt[v]=(cnt[v]+cnt[u])%mod;
}
}
}
int S,T;
int n,m;
signed main(){
n=getint();
m=getint();
S=getint();
T=getint();
for(int re i=1;i<=m;++i){
int u=getint(),v=getint();
ll val=getint();
if(v!=u)
addedge(u,v,val);
}
Dijkstra(distS,cntS,S,T);
Dijkstra(distT,cntT,T,S);
tot=distS[T];
ans=cntS[T]*cntT[S]%mod;
for(int re u=1;u<=n;++u){
if(distS[u]+distT[u]!=tot)continue;
if((!(tot&1))&&distS[u]==(tot>>1)){
ans=(ans-cntS[u]*cntT[u]%mod*cntS[u]%mod*cntT[u]%mod+mod)%mod;
continue;
}
if(distS[u]*2>tot)continue;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(distS[v]!=distS[u]+w[e])continue;
if(distT[v]*2>=tot)continue;
if(tot!=distS[u]+w[e]+distT[v])continue;
ans=(ans-cntS[u]*cntT[v]%mod*cntS[u]%mod*cntT[v]%mod+mod)%mod;
}
}
cout<<(ans%mod+mod)%mod;
return 0;
}