Writeups

All writeups

Balsn CTF 2019: unpredictable

2026/01/03

問題ソースコードは以下の通り。

問題ソースコード
import sys
import os
import hashlib
import random
 
version = sys.version.replace('\n', ' ')
print(f'Python {version}')
random.seed(os.urandom(1337))
 
 
for i in range(0x1337):
    print(random.randrange(3133731337))
 
 
# Encrypt flag
sha512 = hashlib.sha512()
for _ in range(1000):
    rnd = random.getrandbits(32)
    sha512.update(str(rnd).encode('ascii'))
 
key = sha512.digest()
 
with open('../flag.txt', 'rb') as f:
    flag = f.read()
 
enc = bytes(a ^ b for a, b in zip(flag, key))
print('Encrypted:', enc.hex())

random.randrange(3133731337)は3133731337未満の値の値が得られるまでrandom.getrandbits(32)を呼び出すという挙動をするため、この問題において得られる出力列は本来のmt19937の出力列の部分列となっている。出力同士のgapがわかれば線形方程式で状態を復元するいつものテクが使えるため、gapを求めることを考えたい。

untemperした(skipされていない)出力列aを考える。twistの遷移式から、以下の条件がすべてのiについて満たされる。

x = (a[i] & 0x80000000) ^ (a[i + 1] & 0x7fffffff)
assert a[i + 624] == a[i + 397] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)

a[i]は1bitしか使われないので無視すると、ランダムな値に対してこの式が満たされる確率は1/2311/2^{31}と低いので、skipされている出力列においてこの式を満たす組(i,j,k)(i,j,k)が見つかれば、高い確率でそれらの間のgapは396,227396,227とわかる。これを重み付きUnionFindで管理すると、gapが既知のindexがいくつかのグループに分けられる。
ただし、これだけでは得られるgapの情報が十分でない。ここから更にグループをまとめていく必要がある。i番目の値とj番目の値の間のgapがj-iであればその間にskipされた値は存在しないことがわかるので、間のgapはすべて1とわかる。これを使うと、サイズ2000程度のグループが2つ得られる。
適当に計算するとこの2つのグループの間のgapは適当に計算すると9だとわかる。これをもとに、あるグループが他と重複せずにはまるgapが一意に決まるならマージみたいなことを繰り返すとサイズ4000程度のグループが構成できる。これだけ集まれば十分で、あとは線形方程式で状態を復元する。
ソルバは以下の通り。

import random
import hashlib
import sys
from tqdm import trange
 
sys.path.append("/home/yu212/PycharmProjects/ctf/tools/lib")
import random_util
from linear import *
 
class WeightedUnionFind:
    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        self.weight = [0] * (n+1)
    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            y = self.find(self.par[x])
            self.weight[x] += self.weight[self.par[x]]
            self.par[x] = y
            return y
    def union(self, x, y, w):
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1
    def same(self, x, y):
        return self.find(x) == self.find(y)
    def diff(self, x, y):
        return self.weight[x] - self.weight[y]
 
def temper(x):
    x = x ^ (x >> 11)
    x = x ^ ((x << 7) & 0x9d2c5680)
    x = x ^ ((x << 15) & 0xefc60000)
    x = x ^ (x >> 18)
    return x & 0xffffffff
 
class MT19937:
    def __init__(self):
        self.state = LinearFunc.init_array(32, 624)
        self.index = 0
 
    def genrand(self):
        if self.index == 0:
            twist_sym(self.state)
        val = temper(self.state[self.index])
        self.index = (self.index + 1) % 624
        return val
 
def twist_sym(state):
    for i in range(0, 624):
        x = (state[i] & 0x80000000) ^ (state[(i + 1) % 624] & 0x7fffffff)
        state[i] = state[(i + 397) % 624] ^ (x >> 1) ^ (x & 1).select(0, 0x9908b0df)
 
 
with open("output.1.txt", "r") as f:
    lines = f.readlines()
    a = list(map(int, lines[:0x1337]))
    flag_enc = bytes.fromhex(lines[0x1337].removeprefix("Encrypted: "))
 
b = [random_util.untemper(v) for v in a]
 
bmp = {v: i for i, v in enumerate(b)}
c = []
for i in trange(0x1337):
    for j in range(397):
        if i+j >= 0x1337:
            continue
        x = b[i] & 0x7fffffff
        y = x | 0x80000000
        x = b[i+j] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)
        y = b[i+j] ^ (y >> 1) ^ ((y & 1) * 0x9908b0df)
        k = max(bmp.get(x, -1), bmp.get(y, -1))
        if k != -1 and i>0 and b[i-1]>>31 == (x in bmp):
            c.append(-1)
    c.append(b[i])
b = c
bmp = {v: i for i, v in enumerate(b)}
wuf = WeightedUnionFind(len(b))
lj = 0
lk = 0
for i in trange(len(b)):
    for j in range(lj+1, i+397):
        if j >= len(b) or b[i] == -1 or b[j] == -1:
            continue
        x = b[i] & 0x7fffffff
        y = x | 0x80000000
        x = b[j] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)
        y = b[j] ^ (y >> 1) ^ ((y & 1) * 0x9908b0df)
        k = max(bmp.get(x, -1), bmp.get(y, -1))
        if k != -1:
            assert lk < k
            wuf.union(i, j, 396)
            wuf.union(i, k, 623)
            lj = j
            lk = k
            break
 
def calc_group():
    gr = {}
    for i in trange(len(b)):
        if wuf.find(i) not in gr:
            gr[wuf.find(i)] = []
        gr[wuf.find(i)].append(i)
    return gr
 
gr = calc_group()
for _ in range(2):
    for g in gr.values():
        for i in g:
            for j in g:
                if not wuf.same(i, j) or i >= j:
                    continue
                if wuf.diff(i, j) == j-i:
                    for k in range(i+1, j):
                        wuf.union(i, k, k-i)
    gr = calc_group()
 
large_grs = [v for k,v in gr.items() if len(v) > 100]
grs = {v: wuf.diff(large_grs[0][0], v) for v in large_grs[0]} | {v: 9 + wuf.diff(large_grs[1][0], v) for v in large_grs[1]}
 
upd = True
while upd:
    upd = False
    for k, v in gr.items():
        if v[0] in grs:
            continue
        lst = max(kk for kk, vv in grs.items() if kk < v[0])
        nxt = min(kk for kk, vv in grs.items() if kk > v[0])
        val_set = set(grs.values())
        d_cand = []
        for d in range(grs[lst]+v[0]-lst, grs[nxt]-(nxt-v[0])+1):
            if not val_set & {d + wuf.diff(v[0], vv) for vv in v}:
                d_cand.append(d)
        if len(d_cand) == 1:
            grs |= {vv: d_cand[0] + wuf.diff(v[0], vv) for vv in v}
            upd = True
 
st = [-1] * 0x1337 * 2
for k, v in grs.items():
    st[v] = b[k]
 
sol = LinearSolver()
mt = MT19937()
sol.append(mt.state[0][:31])
for i in trange(len(st)):
    if st[i] == -1:
        mt.genrand()
        continue
    sol.append((random_util.temper(st[i]) ^ mt.genrand())[:])
    nsol = sol.num_solutions(32*624)
    if nsol == 1:
        break
 
res = sol.solve()
res = LinearFunc.reshape(res, [32]*624)
random.setstate((3, tuple(res + [624]), None))
aaa = [random.randrange(3133731337) for i in range(0x1337)]
assert aaa == a
 
sha512 = hashlib.sha512()
for _ in range(1000):
    rnd = random.getrandbits(32)
    sha512.update(str(rnd).encode("ascii"))
key = sha512.digest()
print(bytes(a ^ b for a, b in zip(flag_enc, key)))