1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
| #pragma GCC optimize(2) #pragma GCC optimize(3,"Ofast","inline") #include<iostream> #include<cstring> using namespace std; const int N=3e5+10; typedef long long LL; const LL INF=1e18; int h[N],e[N],ne[N],root[N],_idx,n; LL ans[N],a[N]; void add(int a,int b){ e[_idx]=b,ne[_idx]=h[a],h[a]=_idx++; } struct node{ int son[2],p; int size; LL v,min,sum; }tr[N]; int idx; int get_node(int v,int p){ idx++; tr[idx].v=v,tr[idx].p=p; tr[idx].sum=0,tr[idx].min=INF; tr[idx].size=1; return idx; } LL calc(int u,int v){ return (tr[u].v-tr[v].v)*(tr[u].v-tr[v].v); } void pushup(int u){ int l=tr[u].son[0],r=tr[u].son[1]; if(tr[u].p)tr[u].min=min(tr[u].min,calc(u,tr[u].p)); if(l)tr[u].min=min(tr[u].min,calc(u,l)); if(r)tr[u].min=min(tr[u].min,calc(u,r)); tr[u].sum=tr[u].min+tr[l].sum+tr[r].sum; tr[u].size=tr[l].size+tr[r].size+1; } void rotate(int x){ int y=tr[x].p,z=tr[y].p; int k=(x==tr[y].son[1]); tr[z].son[y==tr[z].son[1]]=x,tr[x].p=z; tr[y].son[k]=tr[x].son[k^1],tr[tr[x].son[k^1]].p=y; tr[x].son[k^1]=y,tr[y].p=x; pushup(y),pushup(x); } void splay(int& root,int x,int k){ while(tr[x].p!=k){ int y=tr[x].p,z=tr[y].p; if(z!=k){ if((tr[z].son[1]==y)^(tr[y].son[1]==x))rotate(x); else rotate(y); } rotate(x); } if(!k)root=x; } void insert(int p,int& root){ int u=root,fa=0; while(u){ if(tr[u].v<=tr[p].v)fa=u,u=tr[u].son[1]; else fa=u,u=tr[u].son[0]; } if(!fa)root=p; else tr[fa].son[tr[fa].v<=tr[p].v]=p; tr[p].p=fa; splay(root,p,0); } void Tmerge(int &p, int &root) { if(!p)return; Tmerge(tr[p].son[0], root); Tmerge(tr[p].son[1], root); tr[p].son[0]=tr[p].son[1]=0; tr[p].size=1,tr[p].p=0; tr[p].min=INF,tr[p].sum=0; insert(p,root); } void dfs(int u){ int ptr=get_node(a[u],0); insert(ptr,root[u]); for(int i=h[u];~i;i=ne[i]){ int j=e[i]; dfs(j); if(tr[root[u]].size<tr[root[j]].size)swap(root[u],root[j]); Tmerge(root[j],root[u]); } ans[u]=tr[root[u]].sum; } int main(){ int p; cin>>n; memset(h,-1,sizeof h); for(int i=2;i<=n;i++){ scanf("%d",&p); add(p,i); } for(int i=1;i<=n;i++) scanf("%lld",&a[i]); dfs(1); for(int i=1;i<=n;i++) printf("%lld\n",ans[i]); }
|