핵심 아이디어

<aside> 💡

밑에 있는 것을 위로 올리게 되면 원래 겹치던 것과 겹치지 않게 되는 대신 겹치지 않던 것과 겹치게 된다.

$1\le i\le N$인 $i$에 대해 $i$번째 부터 전부 위로 올리는 경우에 대해 겹치게되는 개수를 구한다.

이를 왼쪽, 오른쪽에 모두에 대해 구한다.

</aside>

오른쪽을 위로 올리는 경우는 왼쪽을 아래로 내리는 효과가 있기 때문에 서로 다른 경우다.


코드

#include <bits/stdc++.h>
#define FASTIO ios_base::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
using namespace std;
typedef long long int ll;

class Seg
{
    private:
        int n;
        vector<ll> tree;

        void upd(int node, int start, int end, int idx, int val)
        {
            if(idx<start || end<idx) return;
            if(start == end){
                tree[node]++;
                return;
            }

            int mid = start+end>>1;
            upd(2*node, start, mid, idx, val);
            upd(2*node+1, mid+1, end, idx, val);

            tree[node] = tree[2*node]+tree[2*node+1];
        }

        ll qry(int node, int start, int end, int left, int right)
        {
            if(end<left || right<start) return 0;
            if(left<=start && end<=right) return tree[node];

            int mid = start+end>>1;
            ll l = qry(2*node, start, mid, left, right);
            ll r = qry(2*node+1, mid+1, end, left, right);

            return l+r;
        }

    public:
        Seg(int n)
        {
            this->n = n;
            tree.resize(4*n);
        }

        void upd(int idx, int val){upd(1, 1, n, idx, val);}
        ll qry(int l, int r) {return qry(1, 1, n, l, r);}

};

int n;

ll solve(vector<int> &fxd, vector<int> &trsd)
{
    vector<int> idx(n+1);
    vector<int> a(n+1);

    Seg cnt(n);
    for(int i = 1; i<=n; i++){
        idx[fxd[i]] = i;
    }

    ll cur = 0;
    for(int i = 1; i<=n; i++){
        cur += cnt.qry(idx[trsd[i]], n);
        cnt.upd(idx[trsd[i]], 1);
        
        a[i] = idx[trsd[i]];
    }

    ll ans = cur;
    for(int i = 1; i<=n; i++){
        cur = cur -(a[i]-1) + (n-a[i]);
        ans = min(ans, cur);
    }

    return ans;
}

signed main()
{
    FASTIO;
    cin >> n;
    vector<int> l(n+1);
    vector<int> r(n+1);
    for(int i = 1; i<=n; i++) cin >> l[i];
    for(int i = 1; i<=n; i++) cin >> r[i];

    ll ans = ll(1e+18);
    ans = min(ans, solve(l, r));
    ans = min(ans, solve(r, l));

    cout << ans <<"\\n";
    return 0;
}