Press ESC to close

如何正确地排序 题解

如何正确地排序 题解

LZDQ

·

2022-03-26 20:15:16

·

题解

不难想到枚举每一行,算最大值出现在这一行的贡献。如果有多个最大值则强制取最前面的。

算第 k 行的贡献时,对每一列令所有数都减去这一列第 k 行的数,这样子一列同时减对于求最大值没有很大影响,并且把第 k 行变成了 0。那么一对 (i,j) 会在第 k 行取到最大值的条件是剩下 m-1 行的和都小于或小于等于 0(能不能取等看是不是没有强制取最前面的限制),不难发现这是 m-1 个大小关系,可以用偏序解决。在 m=4 的时候使用 cdq 分治可以做到 O(n\log^2 n)。

不过这题可以 O(n \log n)。我们只考虑 m=4 的情况。首先需要 \rm min-max 容斥(我一开始觉得这个思路很 sb,结果还真能做),把最小值给反演掉。4 个数的最小值,等于“选 1 个的最大值 - 选 2 个的最大值 + 选 3 个的最大值 - 选 4 个的最大值”。惊人的事情发生了,“减去选 4 个的最大值” 与原本的最大值抵消了。于是答案就等于“选 1 个的最大值 - 选 2 个的最大值 + 选 3 个的最大值”,都可以单 log 解决(不过常数会变大)。

#include

#include

#include

using namespace std;

typedef long long ll;

typedef pair pr;

const int MAXN=2e5+5;

int n,m,r,a[5][MAXN];

struct node{

int a[4];

bool q;

int& operator [](int x){

return a[x];

}

}b[MAXN*2],c[MAXN*2];

bool operator <(node a,node b){

if(a[1]==b[1]){

if(r==1) return !a.q&&b.q;

return a.q&&!b.q;

}

return a[1]

}

pr rsum[MAXN<<2];

inline void add(int w,int x,int c){

for(int i=w+MAXN*2; i

rsum[i].first+=c*x;

rsum[i].second+=c;

}

}

inline pr query(int w){

pr res(0ll,0);

for(int i=w+MAXN*2; i; i&=i-1)

res.first+=rsum[i].first,res.second+=rsum[i].second;

return res;

}

ll res,ans;

ll Calc(){

res=0;

for(r=1; r<=m; r++){

for(int k=1; k

for(int i=1; i<=n; i++)

b[i][k]=a[k][i]-a[r][i];

for(int k=r+1; k<=m; k++)

for(int i=1; i<=n; i++)

b[i][k-1]=a[k][i]-a[r][i];

for(int i=1; i<=n; i++){

b[i].q=0;

b[i+n].q=1;

b[i][0]=b[i+n][0]=a[r][i];

for(int k=1; k

b[i+n][k]=-b[i][k];

}

sort(b+1,b+n*2+1);

if(m==3){

memset(rsum,0,sizeof(rsum));

for(int i=1; i<=n*2; i++)

if(b[i].q){

pr t=query(b[i][2]-(r>=3));

res+=t.first;

res+=1ll*t.second*b[i][0];

}else add(b[i][2],b[i][0],1);

}else{

ll sum=0,tot=0;

for(int i=1; i<=n*2; i++)

if(b[i].q) res+=sum+tot*b[i][0];

else sum+=b[i][0],tot++;

}

}

return res;

}

int main(){

scanf("%d%d",&m,&n);

for(int i=1; i<=m; i++)

for(int j=1; j<=n; j++)

scanf("%d",a[i]+j);

if(m==2){

for(int i=1; i<=m; i++)

for(int j=1; j<=n; j++)

ans+=a[i][j];

ans*=n*2;

}else if(m==3){

ans=Calc();

for(int i=1; i<=m; i++)

for(int j=1; j<=n; j++)

a[i][j]=-a[i][j];

ans-=Calc();

}else{

for(int i=1; i<=m; i++)

for(int j=1; j<=n; j++)

ans+=a[i][j];

ans*=n*2;

static int _a[5][MAXN];

memcpy(_a,a,sizeof(a));

for(int k=0; k<16; k++){

m=0;

for(int i=1; i<=4; i++)

if(k&1<

m++;

for(int j=1; j<=n; j++)

a[m][j]=_a[i][j];

}

if(m==2) ans-=Calc();

else if(m==3) ans+=Calc();

}

}

printf("%lld\n",ans);

return 0;

}