少ない学びをせめて記録する

技術記録、競プロメモ、その他調べたことを書く @京都, twitter : @nehan_der_thal

AtCoder ABC158 メモ

ABCDを解きました。水色パフォでした。今週はコドフォも含めて全体的にいつもより順位が低めでした。練習の仕方に問題があるのかも。

D String Formation

後ろから解きました。Dequeと使うのが簡単なようですが私は最終的な文字列の長さ分リストを用意しておいて、前に置くときはX[i] = x;i+=1;後ろに置くときはX[j] = x;j-=1のようにしました。 Dequeは最終的な長さがわからないときにこそ有効と思っているのですが、バグの元なので多少オーバーキルでもDequeを使ったほうがいいかもしれないです。

E Divisible Substring

わかりませんでした。かなり難しく感じます。modがなお苦手なんだなあという感想です。終了後、解説通りにACしました。 mod系の問題をたくさんといて克服したいです。

F Removing Robots

Fにしては簡単だった気がします。Eを早めに見切りつけてれば解けた気もします。 終了後、自力ACしました。 各ロボットの位置をソートして添字を振り直します。各ロボットiの位置はX1, X2,...Xnになります。

  1. 各ロボットiが直接影響するロボットの区間はi+1..Ciとなるはずです。Ciを求めます。

  2. 後ろからdpします。dp[i] = ロボットi...Nだけについて考えて最終的に残るロボットの数です。 これはdp[i] = dp[i+1] + dp[max(dp[i+1:Ci + 1])]で更新できます。(1つ目の項はiのスイッチを入れなかったときで、2つめはスイッチを入れたときです。それぞれiがいないかいるかなので独立で足せます。)

1でCiを求めるために苦戦してしまい間に合いませんでした。 xi, xi+diを混ぜてリストに入れて、ソートして順に見ていってxj+djが来たら直前のxに相当するiをC[j]=iとして計算。という風にしていました。 以下コードです。

class SegTree:
    def __init__(self, init_val, ide_ele, seg_func):
        self.segfunc = seg_func
        n = len(init_val)
        self.num = 2**(n-1).bit_length()
        self.ide_ele = ide_ele
        self.seg=[self.ide_ele]*2*self.num
        for i in range(n):
            self.seg[i+self.num-1]=init_val[i]    
        for i in range(self.num-2,-1,-1) :
            self.seg[i]=self.segfunc(self.seg[2*i+1],self.seg[2*i+2]) 
        
    def update(self, k, x):
        k += self.num-1
        self.seg[k] = x
        # k+1ではなくkでは?
        while k+1:
            k = (k-1)//2
            self.seg[k] = self.segfunc(self.seg[k*2+1],self.seg[k*2+2])
        
    def query(self, p, q):
        if q<=p:
            return self.ide_ele
        p += self.num-1
        q += self.num-2
        res=self.ide_ele
        while q-p>1:
            if p&1 == 0:
                res = self.segfunc(res,self.seg[p])
            if q&1 == 1:
                res = self.segfunc(res,self.seg[q])
                q -= 1
            p = p//2
            q = (q-1)//2
        if p == q:
            res = self.segfunc(res,self.seg[p])
        else:
            res = self.segfunc(self.segfunc(res,self.seg[p]),self.seg[q])
        return res
import sys;input=sys.stdin.readline
mod = 998244353
N=int(input())
X = []
Y = []
for _ in range(N):
    x, d = map(int, input().split())
    X.append((x,d))
X.sort(key=lambda x:x[0])
for i, (x, d) in enumerate(X):
    Y.append((2*x,i,0))
    Y.append((2*x+2*d-1,i,1))
Y.sort(key=lambda x:x[0])
C = [0]*N
for x, i, t in Y:
    if t==0:
        j = i
    elif t==1:
        C[i] = j
st = SegTree(C, -1, max)
dp = [1]*(N+1)
for i in range(N-1, -1, -1):
    y=st.query(i, C[i]+1)
    st.update(i, y)
    dp[i] = (dp[i+1]+dp[y+1]) % mod
print(dp[0])

Ciを求めるパートは単に二分探索の方が良かったかもしれないです。 二分探索でバグったことが多すぎて使うのを躊躇っている節があります。 二分探索使って書き直しました。

class SegTree:
    def __init__(self, init_val, ide_ele, seg_func):
        self.segfunc = seg_func
        n = len(init_val)
        self.num = 2**(n-1).bit_length()
        self.ide_ele = ide_ele
        self.seg=[self.ide_ele]*2*self.num
        for i in range(n):
            self.seg[i+self.num-1]=init_val[i]    
        for i in range(self.num-2,-1,-1) :
            self.seg[i]=self.segfunc(self.seg[2*i+1],self.seg[2*i+2]) 
        
    def update(self, k, x):
        k += self.num-1
        self.seg[k] = x
        # k+1ではなくkでは?
        while k+1:
            k = (k-1)//2
            self.seg[k] = self.segfunc(self.seg[k*2+1],self.seg[k*2+2])
        
    def query(self, p, q):
        if q<=p:
            return self.ide_ele
        p += self.num-1
        q += self.num-2
        res=self.ide_ele
        while q-p>1:
            if p&1 == 0:
                res = self.segfunc(res,self.seg[p])
            if q&1 == 1:
                res = self.segfunc(res,self.seg[q])
                q -= 1
            p = p//2
            q = (q-1)//2
        if p == q:
            res = self.segfunc(res,self.seg[p])
        else:
            res = self.segfunc(self.segfunc(res,self.seg[p]),self.seg[q])
        return res

def bsearch(mn, mx, func):
    #func(i)=False を満たす最大のi
    idx = (mx + mn)//2
    while mx-mn>1:
        if func(idx):
            idx, mx = (idx + mn)//2, idx
            continue
        idx, mn = (idx + mx)//2, idx
    return idx

import sys;input=sys.stdin.readline
mod = 998244353
N=int(input())
X = []
for _ in range(N):
    x, d = map(int, input().split())
    X.append((x,d))
X.sort()
C = [0] * N
for i in range(N):
    x, d = X[i]
    k = bsearch(-1, N, lambda j: X[j][0] >= x+d)
    C[i] = k

st = SegTree(C, -1, max)
dp = [1]*(N+1)
for i in range(N-1, -1, -1):
    y=st.query(i, C[i]+1)
    st.update(i, y)
    dp[i] = (dp[i+1]+dp[y+1]) % mod
print(dp[0])