题目描述

你有个长度为 nn 的数组 aa 和一个长度为 nn 的排列 pp,对于每一个 ii 有一有向边 (i,pi)(i,p_i)

有如下三种操作:

  • 1 l r 询问 i=lrai\sum_{i=l}^r a_i

  • 2 v x 将所有 vv 能到达的节点所对应编号的值加 xx

  • 3 x y 交换 pxp_xpyp_y

对于每一 11 操作输出结果。

1n,m2×1051 \le n,m \le 2\times 10^5。8s

思路

数据范围和 8s 实现提醒我们考虑根号算法。序列分块感觉没有什么性质,考虑将操作按 BB 分成 mB\frac{m}{B} 块。对于一块中的询问,只会有 O(B)\mathcal{O}(B) 个位置被改变。我们只需要将不被改变的位置连续的点缩掉,就只需要维护一个 O(B)\mathcal{O}(B) 点的环。这个缩点的过程是 O(n)\mathcal{O}(n) 的。

于是考虑这三个操作。三操作暴力交换即可,时间复杂度 O(B)\mathcal{O}(B) 。二操作只需要在缩好的点上打标记即可,时间复杂度 O(B)\mathcal{O}(B) 。一操作我们只需要算增量,在三操作上对于每个块二分一下有多少被包含即可,时间复杂度 O(BlogB)\mathcal{O}(B\log B)。于是算得总时间复杂度为 O(mB(n+B+BlogB))\mathcal{O}(\frac{m}{B}(n + B + B \log B))。假设 m=Θ(n)m = \Theta(n),取 B=O(nlogn)B = \mathcal{O}(\sqrt{\frac{n}{\log n}}) 时间复杂度为 O(nnlogn)\mathcal{O}(n \sqrt{n \log n})。这样已经可以过了。代码也是这么实现的。

其实可以更快。可以对于缩点后的集合记录 1i1 \sim i 的点分别出现多少次,可以 O(B2)\mathcal{O}(B^2) 预处理,将总时间复杂度降到了 O(nn)\mathcal{O}(n \sqrt n)

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
using ll = long long;

struct _ {
    int op, l, r;
} q[N];

int n, m, p[N];
ll sum[N], tag[N], a[N];
int vis[N], id[N], nxt[N], la[N], ip[N];
vector<int> G[N];

int main() {
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n; i++) cin >> p[i], ip[p[i]] = i;
    cin >> m;
    int B = 1000;
    for (int ql = 1, qr; ql <= m; ql += B) {
        for (int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
        memset(vis, false, sizeof vis); memset(tag, 0, sizeof tag); memset(id, 0, sizeof id);
        for (int i = 1; i <= n; i++) G[i].clear();
        qr = min(m, ql + B - 1);    
        int qwq = qr - ql + 1;
        for (int i = 1; i <= qwq; i++) {
            cin >> q[i].op >> q[i].l >> q[i].r;
            if (q[i].op == 2) vis[q[i].l] = 1;
            if (q[i].op == 3) vis[q[i].l] = 1, vis[q[i].r] = 1;
        }
        int cnt = 0;
        static int idx = 0;
        ++idx;
        for (int i = 1; i <= n; i++) {
            if (vis[i]) {
                int x = i; cnt++;
                x = ip[i];
                id[i] = cnt;
                while (!vis[x]) {
                    id[x] = cnt;
                    x = ip[x];  
                }
            }
        }
        for (int i = 1; i <= n; i++)
            if (vis[i])
                nxt[id[i]] = id[p[i]];
        for (int i = 1; i <= n; i++) G[id[i]].push_back(i);
        for (int i = 1; i <= qwq; i++) {
            if (q[i].op == 1) {
                ll res = 0;
                for (int j = 1; j <= cnt; j++)
                    res += 1ll * tag[j] * max(0ll, 1ll * (lower_bound(G[j].begin(), G[j].end(), q[i].r + 1) - G[j].begin() - (lower_bound(G[j].begin(), G[j].end(), q[i].l) - G[j].begin())));
                cout << sum[q[i].r] - sum[q[i].l - 1] + res << endl;
            } else  if (q[i].op == 2) {
                int x = id[q[i].l];
                idx++;
                while (la[x] != idx && x) {
                    la[x] = idx;
                    tag[x] += q[i].r;
                    x = nxt[x];
                }
            } else {
                swap(nxt[id[q[i].l]], nxt[id[q[i].r]]);
                swap(ip[p[q[i].l]], ip[p[q[i].r]]);
                swap(p[q[i].l], p[q[i].r]);
            }
        }
        for (int i = 1; i <= n; i++) a[i] += tag[id[i]];
    } 
    return 0;
}