핵심 아이디어

<aside> 💡

다익스트라 진행 하면서 $S$부터 현재 정점 $i$까지의 최단거리 위에 있던 정점들 중 최솟값을 $\text{curMin}_i$를 저장한다.

후에 $S$에서 $T$까지의 최단 거리 위에 정점들에 대해서 DFS를 돌며 $a_i + \text{curMin}_i$를 정답에 갱신한다.

</aside>


코드

#include <bits/stdc++.h>
#define FASTIO ios_base::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
using namespace std;
typedef long long ll;
#define INF 9187201950435737471

int N, M, S, T;
int a[1'000'001];
vector<pair<int, int>> adj[1'000'001];

priority_queue<tuple<ll, int, int>, vector<tuple<ll, int, int>>, greater<>> pq;
ll dist[1'000'001];

int tans[1'000'001];
int curMn[1'000'001];

int ans = 1e+9;
bool reach[1'000'001];
bool v[1'000'001];

bool dfs(int cur)
{
    v[cur] = true;
    
    if(cur == T){
        ans = min(ans, tans[cur]);
        reach[cur] = true;
        return true;
    }

    bool flag = false;
    for(auto[nxt, nw]: adj[cur]){
        if(dist[cur]+nw != dist[nxt]) continue;

        if(!v[nxt]){
            if(dfs(nxt)){
                flag = true;
                reach[cur] = true;
                ans = min(ans, tans[cur]);
            }
        }
        else{
            if(reach[nxt]){
                ans = min(ans, tans[cur]);
                flag = true;
            }
        }
    }

    if(flag) return true;
    else return false;
}

signed main()
{
    FASTIO;
    cin >> N >> M;
    for(int i = 1; i<=N; i++) cin >> a[i];
    cin >> S >> T;
    for(int i = 0; i<M; i++){
        int s, e, w;
        cin >> s >> e >> w;
        adj[s].push_back({e, w});
    }

    memset(dist, 0x7f, sizeof(dist));
    memset(tans, 0x7f, sizeof(tans));
    memset(curMn, 0x7f, sizeof(curMn));

    dist[S] = 0;
    curMn[S] = 1e+8;
    
    pq.push({0, S, 1e+8});
    while(!pq.empty()){
        auto[d, cur, mn] = pq.top();
        pq.pop();

        if(dist[cur] != d) continue;
        if(curMn[cur] != mn) continue;

        tans[cur] = min(tans[cur], a[cur]+mn);

        for(auto[nxt, nw]: adj[cur]){
            if(dist[nxt]< d+nw) continue;
            
            if(dist[nxt]>d+nw){
                dist[nxt] = d+nw;
                curMn[nxt] = min(mn, a[cur]);
                pq.push({dist[nxt], nxt, curMn[nxt]});
            }
            else if(dist[nxt] == d+nw){
                if(curMn[nxt] > min(mn, a[cur])){
                    curMn[nxt] = min(mn, a[cur]);
                    pq.push({dist[nxt], nxt, curMn[nxt]});    
                }
            }
        }
    }

    dfs(S);
    if(ans == 1000000000) cout <<"-1";
    else cout << ans <<"\\n";
    return 0;
}