這個題很不錯哦,用到了最短路+二分答案+dfs,出題人太厲害了
首先,100個點,總共的點的標號數目可能達到100,壓縮不了,那就只能dfs了,但肯定需要剪枝
我自己想到的剪枝就是,先不管標號的問題,從終點做一次最短路,記錄路徑,如果源點不可達,則無解,如果這條路徑上的點剛好標號都不一樣,則輸出到源點的最短路徑
然後就暴力dfs,如果當前長度加上不考慮標號時當前點到終點的最短路徑都大於當前最優解,就返回,這樣還是T了
後來想能否快速找到一個可行解作為上限值,但不會找,查瞭解題報告才知道,可以二分答案,相當於有目的地找多次可行解,然後去驗證,很巧妙
這題最讓我覺得不可思議的就是每次dfs到一個點的時候都去求一次該點到終點的最短路,用來剪枝,好神奇啊
還有一個值得注意的地方就是,dfs的時候,在進入某個點之前把它標記,然後dfs回來時再撤銷標記,否則,在進入某個點後標記,退出這個點時在撤銷,可能會由於剪枝而忘記撤銷標記的情況
代碼:
#include<iostream>#include<cstring>#include<cstdio>#include<algorithm>#include<cmath>#include<stack>#include<queue>#include<vector>#include<map>#include<ctime>using namespace std;const int MAX=1005;const int inf=1<<26;struct node { int v,w,next;}g[MAX*100];int adj[MAX],e,vis1[MAX],vis2[MAX],kind[MAX],n,m;int dis[MAX],fa[MAX],pre[MAX];int flag[MAX];bool pos[MAX],found;void add(int u,int v,int w){ g[e].v=v; g[e].w=w; g[e].next=adj[u]; adj[u]=e++;}void spfa(int s,int t,int k){ int i,u,v; queue<int>que; for(i=0;i<=n;i++) dis[i]=inf; if(k) memset(pre,-1,sizeof(pre)); dis[s]=0; memset(vis1,0,sizeof(vis1)); vis1[s]=1; que.push(s); while(!que.empty()) { u=que.front(); que.pop(); vis1[u]=0; for(i=adj[u];i!=-1;i=g[i].next) { v=g[i].v; if(vis2[kind[v]]) continue; //if(kind[v]==kind[t]) //continue; //if(kind[v]==kind[s]&&v!=s) //continue; if(dis[v]>dis[u]+g[i].w) { dis[v]=dis[u]+g[i].w; pre[v]=u; if(!vis1[v]) { vis1[v]=1; que.push(v); } } } }} bool check(){ for(int i=0;i<MAX;i++) if(flag[i]>1) return false; return true;}bool dfs(int u,int now,int limit,int t){ if(now>limit) return false; if(u==t) return true; spfa(u,t,0); if(now+dis[t]>limit) return false; int i,v; for(i=adj[u];i!=-1;i=g[i].next) { v=g[i].v; if(vis2[kind[v]]) continue; vis2[kind[v]]=1; if(dfs(v,now+g[i].w,limit,t)) return true; vis2[kind[v]]=0; } return false;}void solve(int s,int t,int sum){ int l=1,r=sum,ans=-1,mid; while(l<=r) { mid=(l+r)>>1; memset(vis2,0,sizeof(vis2)); vis2[kind[s]]=1; if(dfs(s,0,mid,t)) { ans=mid; r=mid-1; } else l=mid+1; } printf("%d\n",ans);}inline int nextInt(){int res = 0; char ch;bool neg = false;while (ch = getchar(), (ch < '0' || ch > '9') && ch != '-');if (ch == '-') neg = true;else res = ch - '0';while (ch = getchar(), ch >= '0' && ch <= '9') res = res * 10 + ch - '0';if (neg) res = - res;return res;}int main(){ int i,j,k,T,s,t,sum=0; scanf("%d",&T); while(T--) { scanf("%d%d%d%d",&n,&m,&s,&t); memset(adj,-1,sizeof(adj)); e=0; while(m--) { //scanf("%d%d%d",&i,&j,&k); i=nextInt(); j=nextInt(); k=nextInt(); add(i,j,k); add(j,i,k); sum+=k; } for(i=0;i<n;i++) { //scanf("%d",&kind[i]); kind[i]=nextInt(); } if(s==t) { puts("0"); continue; } memset(vis2,0,sizeof(vis2)); vis2[kind[t]]=1; spfa(t,s,1); memset(flag,0,sizeof(flag)); if(pre[s]==-1) { puts("-1"); continue; } for(i=s;i!=-1;i=pre[i]) { flag[kind[i]]++; } if(check()) { printf("%d\n",dis[s]); continue; } solve(s,t,sum); } return 0;}