ABC127-F Absolute Minima
二通りの方法で通した。
- 優先度付キュー
- Binary indexed tree + 二分探索
始め思いついたのは2つ目の方法。 これには座標圧縮なども必要になり大変だった。 1つ目の方法は公式pdfを参考にしたもの。
共通して必要な考察
|x-ai| + bi を順に足していって任意のタイミングで最小値と最小を取るxを求める問題。
|x-a1|+|x-a2|+...+|x-ak|+b1+b2+... +bk の 最小値はどう求めたらいいでしょうか。
|x-ai|が正になる項と負になる項の数が同じなら、絶対値記号を外したとき+xと-xの数が同じになるので定数項のみが残る。 aiでソートしておいて、(大きい方半分のaiの和)- (小さい方半分のaiの和)+ sum(bi) を計算すればいい。 (kが奇数の時は中間値のaiを無視して前半のk//2こと後半のk//2このみ見る。)
また最小を取るxは偶数の時はa{k//2} < x < a{k//2 + 1}の任意のxとなり、奇数の時はx = a_{k//2 + 1}となる。(aiはソート済みとして1-indexとする。)
ソートしながらaiをリストに格納していき、k//2番目の要素にアクセスできること、また、その要素までのsumを取ること、ができればこの問題は解ける。
解法1. 優先度付キュー
解説pdf通りに解いた。頭のいい解法だと感じた。
優先度付キューではソートしながら値を格納することが可能だが、最小値しか取得できない。 そこで優先度付キューを2つ用意して小さい方半分を扱うリストと、大きい方半分を扱うリストを用意することで解決できる。 また、普通にやると全体の要素数が偶数の時と奇数の時で場合分けが必要だが、一度に値を2つ更新することでこれも解決できる。 例えば、1, 3, 4と数値を入れた後、2つのリストはこうなる。
[1, 1, 3] [3, 4, 4]
次に2を入れると
[1, 1, 2, 2] [3, 3, 4, 4]
次に5を入れると
[1, 1, 3, 3] [4, 4, 5, 5]
というように境界をずらしながら入れていくと良い。 全体の要素数が奇数の時も、偶数のように扱えるのが頭いいポイントだと思う。
また、合計値を取るのも前回からの差分のみを取れば良いが、これも結構悩んだ。解説pdf参照。
from heapq import * import sys;input=sys.stdin.readline Q, = map(int, input().split()) qs = [] c = 0 for i in range(Q): q = input().split() if len(q) == 1: c += 1 else: a, b = int(q[1]), int(q[2]) qs.append((a, b, c)) c = 0 if c: qs.append((a, b, c)) R = [10**18] L = [10**18] ss = 0 for a, b, c in qs: xs = [] if c: x = heappop(L) for _ in range(c): print(-x, ss) heappush(L, x) x = heappop(R) y = -heappop(L) if x < a: heappush(R, a) heappush(R, a) heappush(L, -x) heappush(L, -y) ss += a-x elif y > a: heappush(L, -a) heappush(L, -a) heappush(R, x) heappush(R, y) ss += y-a else: heappush(L, -a) heappush(R, a) heappush(L, -y) heappush(R, x) ss += b
解法2. Binary indexed tree + 二分探索
とりあえずコードだけ...
class Bit: def __init__(self, n): self.size = n self.tree = [0] * (n + 1) def sum(self, i): s = 0 while i > 0: s += self.tree[i] i -= i & -i return s def add(self, i, x): while i <= self.size: self.tree[i] += x i += i & -i def bsearch(mn, mx, func): #func(i)=True を満たす最大のi (mn<=i<mx) idx = (mx + mn)//2 while mx-mn>1: if func(idx): idx, mn = (idx + mx)//2, idx continue idx, mx = (idx + mn)//2, idx return idx import sys;input=sys.stdin.readline Q, = map(int, input().split()) qs = [] c = 0 j = 0 for _ in range(Q): q = input().split() if len(q) == 1: c += 1 else: a, b = int(q[1]), int(q[2]) qs.append((a, b, c, j)) c = 0 j+=1 if c: qs.append((a, b, c, j)) m = j qqs = sorted(qs, key=lambda x:x[0]) qid2i = dict() i2a = dict() for i, (a, b, c, qid) in enumerate(qqs): j2i[qid] = i+1 i2a[i+1] = a m = len(j2i) st = Bit(m) st_sum = Bit(m) asum=bsum=0 for a, b, c, i in qs: if c: j = bsearch(0, m+1, lambda x: st.sum(x)<ci) for _ in range(c): if not i%2: print(i2a[j+1], -2*st_sum.sum(j+1)+asum+bsum) else: print(i2a[j+1], -st_sum.sum(j)+(asum-st_sum.sum(j+1))+bsum) asum += a bsum += b st.add(j2i[i], 1) st_sum.add(j2i[i], a) ci = -(-(i+1)//2)