如何正确地排序 题解
LZDQ
·
2022-03-26 20:15:16
·
题解
不难想到枚举每一行,算最大值出现在这一行的贡献。如果有多个最大值则强制取最前面的。
算第 k 行的贡献时,对每一列令所有数都减去这一列第 k 行的数,这样子一列同时减对于求最大值没有很大影响,并且把第 k 行变成了 0。那么一对 (i,j) 会在第 k 行取到最大值的条件是剩下 m-1 行的和都小于或小于等于 0(能不能取等看是不是没有强制取最前面的限制),不难发现这是 m-1 个大小关系,可以用偏序解决。在 m=4 的时候使用 cdq 分治可以做到 O(n\log^2 n)。
不过这题可以 O(n \log n)。我们只考虑 m=4 的情况。首先需要 \rm min-max 容斥(我一开始觉得这个思路很 sb,结果还真能做),把最小值给反演掉。4 个数的最小值,等于“选 1 个的最大值 - 选 2 个的最大值 + 选 3 个的最大值 - 选 4 个的最大值”。惊人的事情发生了,“减去选 4 个的最大值” 与原本的最大值抵消了。于是答案就等于“选 1 个的最大值 - 选 2 个的最大值 + 选 3 个的最大值”,都可以单 log 解决(不过常数会变大)。
#include
#include
#include
using namespace std;
typedef long long ll;
typedef pair
const int MAXN=2e5+5;
int n,m,r,a[5][MAXN];
struct node{
int a[4];
bool q;
int& operator [](int x){
return a[x];
}
}b[MAXN*2],c[MAXN*2];
bool operator <(node a,node b){
if(a[1]==b[1]){
if(r==1) return !a.q&&b.q;
return a.q&&!b.q;
}
return a[1]
}
pr rsum[MAXN<<2];
inline void add(int w,int x,int c){
for(int i=w+MAXN*2; i rsum[i].first+=c*x; rsum[i].second+=c; } } inline pr query(int w){ pr res(0ll,0); for(int i=w+MAXN*2; i; i&=i-1) res.first+=rsum[i].first,res.second+=rsum[i].second; return res; } ll res,ans; ll Calc(){ res=0; for(r=1; r<=m; r++){ for(int k=1; k for(int i=1; i<=n; i++) b[i][k]=a[k][i]-a[r][i]; for(int k=r+1; k<=m; k++) for(int i=1; i<=n; i++) b[i][k-1]=a[k][i]-a[r][i]; for(int i=1; i<=n; i++){ b[i].q=0; b[i+n].q=1; b[i][0]=b[i+n][0]=a[r][i]; for(int k=1; k b[i+n][k]=-b[i][k]; } sort(b+1,b+n*2+1); if(m==3){ memset(rsum,0,sizeof(rsum)); for(int i=1; i<=n*2; i++) if(b[i].q){ pr t=query(b[i][2]-(r>=3)); res+=t.first; res+=1ll*t.second*b[i][0]; }else add(b[i][2],b[i][0],1); }else{ ll sum=0,tot=0; for(int i=1; i<=n*2; i++) if(b[i].q) res+=sum+tot*b[i][0]; else sum+=b[i][0],tot++; } } return res; } int main(){ scanf("%d%d",&m,&n); for(int i=1; i<=m; i++) for(int j=1; j<=n; j++) scanf("%d",a[i]+j); if(m==2){ for(int i=1; i<=m; i++) for(int j=1; j<=n; j++) ans+=a[i][j]; ans*=n*2; }else if(m==3){ ans=Calc(); for(int i=1; i<=m; i++) for(int j=1; j<=n; j++) a[i][j]=-a[i][j]; ans-=Calc(); }else{ for(int i=1; i<=m; i++) for(int j=1; j<=n; j++) ans+=a[i][j]; ans*=n*2; static int _a[5][MAXN]; memcpy(_a,a,sizeof(a)); for(int k=0; k<16; k++){ m=0; for(int i=1; i<=4; i++) if(k&1< m++; for(int j=1; j<=n; j++) a[m][j]=_a[i][j]; } if(m==2) ans-=Calc(); else if(m==3) ans+=Calc(); } } printf("%lld\n",ans); return 0; }