Balsn CTF 2019: unpredictable
2026/01/03
問題ソースコードは以下の通り。
問題ソースコード
import sys
import os
import hashlib
import random
version = sys.version.replace('\n', ' ')
print(f'Python {version}')
random.seed(os.urandom(1337))
for i in range(0x1337):
print(random.randrange(3133731337))
# Encrypt flag
sha512 = hashlib.sha512()
for _ in range(1000):
rnd = random.getrandbits(32)
sha512.update(str(rnd).encode('ascii'))
key = sha512.digest()
with open('../flag.txt', 'rb') as f:
flag = f.read()
enc = bytes(a ^ b for a, b in zip(flag, key))
print('Encrypted:', enc.hex())random.randrange(3133731337)は3133731337未満の値の値が得られるまでrandom.getrandbits(32)を呼び出すという挙動をするため、この問題において得られる出力列は本来のmt19937の出力列の部分列となっている。出力同士のgapがわかれば線形方程式で状態を復元するいつものテクが使えるため、gapを求めることを考えたい。
untemperした(skipされていない)出力列aを考える。twistの遷移式から、以下の条件がすべてのiについて満たされる。
x = (a[i] & 0x80000000) ^ (a[i + 1] & 0x7fffffff)
assert a[i + 624] == a[i + 397] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)a[i]は1bitしか使われないので無視すると、ランダムな値に対してこの式が満たされる確率はと低いので、skipされている出力列においてこの式を満たす組が見つかれば、高い確率でそれらの間のgapはとわかる。これを重み付きUnionFindで管理すると、gapが既知のindexがいくつかのグループに分けられる。
ただし、これだけでは得られるgapの情報が十分でない。ここから更にグループをまとめていく必要がある。i番目の値とj番目の値の間のgapがj-iであればその間にskipされた値は存在しないことがわかるので、間のgapはすべて1とわかる。これを使うと、サイズ2000程度のグループが2つ得られる。
適当に計算するとこの2つのグループの間のgapは適当に計算すると9だとわかる。これをもとに、あるグループが他と重複せずにはまるgapが一意に決まるならマージみたいなことを繰り返すとサイズ4000程度のグループが構成できる。これだけ集まれば十分で、あとは線形方程式で状態を復元する。
ソルバは以下の通り。
import random
import hashlib
import sys
from tqdm import trange
sys.path.append("/home/yu212/PycharmProjects/ctf/tools/lib")
import random_util
from linear import *
class WeightedUnionFind:
def __init__(self, n):
self.par = [i for i in range(n+1)]
self.rank = [0] * (n+1)
self.weight = [0] * (n+1)
def find(self, x):
if self.par[x] == x:
return x
else:
y = self.find(self.par[x])
self.weight[x] += self.weight[self.par[x]]
self.par[x] = y
return y
def union(self, x, y, w):
rx = self.find(x)
ry = self.find(y)
if rx == ry:
return
if self.rank[rx] < self.rank[ry]:
self.par[rx] = ry
self.weight[rx] = w - self.weight[x] + self.weight[y]
else:
self.par[ry] = rx
self.weight[ry] = -w - self.weight[y] + self.weight[x]
if self.rank[rx] == self.rank[ry]:
self.rank[rx] += 1
def same(self, x, y):
return self.find(x) == self.find(y)
def diff(self, x, y):
return self.weight[x] - self.weight[y]
def temper(x):
x = x ^ (x >> 11)
x = x ^ ((x << 7) & 0x9d2c5680)
x = x ^ ((x << 15) & 0xefc60000)
x = x ^ (x >> 18)
return x & 0xffffffff
class MT19937:
def __init__(self):
self.state = LinearFunc.init_array(32, 624)
self.index = 0
def genrand(self):
if self.index == 0:
twist_sym(self.state)
val = temper(self.state[self.index])
self.index = (self.index + 1) % 624
return val
def twist_sym(state):
for i in range(0, 624):
x = (state[i] & 0x80000000) ^ (state[(i + 1) % 624] & 0x7fffffff)
state[i] = state[(i + 397) % 624] ^ (x >> 1) ^ (x & 1).select(0, 0x9908b0df)
with open("output.1.txt", "r") as f:
lines = f.readlines()
a = list(map(int, lines[:0x1337]))
flag_enc = bytes.fromhex(lines[0x1337].removeprefix("Encrypted: "))
b = [random_util.untemper(v) for v in a]
bmp = {v: i for i, v in enumerate(b)}
c = []
for i in trange(0x1337):
for j in range(397):
if i+j >= 0x1337:
continue
x = b[i] & 0x7fffffff
y = x | 0x80000000
x = b[i+j] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)
y = b[i+j] ^ (y >> 1) ^ ((y & 1) * 0x9908b0df)
k = max(bmp.get(x, -1), bmp.get(y, -1))
if k != -1 and i>0 and b[i-1]>>31 == (x in bmp):
c.append(-1)
c.append(b[i])
b = c
bmp = {v: i for i, v in enumerate(b)}
wuf = WeightedUnionFind(len(b))
lj = 0
lk = 0
for i in trange(len(b)):
for j in range(lj+1, i+397):
if j >= len(b) or b[i] == -1 or b[j] == -1:
continue
x = b[i] & 0x7fffffff
y = x | 0x80000000
x = b[j] ^ (x >> 1) ^ ((x & 1) * 0x9908b0df)
y = b[j] ^ (y >> 1) ^ ((y & 1) * 0x9908b0df)
k = max(bmp.get(x, -1), bmp.get(y, -1))
if k != -1:
assert lk < k
wuf.union(i, j, 396)
wuf.union(i, k, 623)
lj = j
lk = k
break
def calc_group():
gr = {}
for i in trange(len(b)):
if wuf.find(i) not in gr:
gr[wuf.find(i)] = []
gr[wuf.find(i)].append(i)
return gr
gr = calc_group()
for _ in range(2):
for g in gr.values():
for i in g:
for j in g:
if not wuf.same(i, j) or i >= j:
continue
if wuf.diff(i, j) == j-i:
for k in range(i+1, j):
wuf.union(i, k, k-i)
gr = calc_group()
large_grs = [v for k,v in gr.items() if len(v) > 100]
grs = {v: wuf.diff(large_grs[0][0], v) for v in large_grs[0]} | {v: 9 + wuf.diff(large_grs[1][0], v) for v in large_grs[1]}
upd = True
while upd:
upd = False
for k, v in gr.items():
if v[0] in grs:
continue
lst = max(kk for kk, vv in grs.items() if kk < v[0])
nxt = min(kk for kk, vv in grs.items() if kk > v[0])
val_set = set(grs.values())
d_cand = []
for d in range(grs[lst]+v[0]-lst, grs[nxt]-(nxt-v[0])+1):
if not val_set & {d + wuf.diff(v[0], vv) for vv in v}:
d_cand.append(d)
if len(d_cand) == 1:
grs |= {vv: d_cand[0] + wuf.diff(v[0], vv) for vv in v}
upd = True
st = [-1] * 0x1337 * 2
for k, v in grs.items():
st[v] = b[k]
sol = LinearSolver()
mt = MT19937()
sol.append(mt.state[0][:31])
for i in trange(len(st)):
if st[i] == -1:
mt.genrand()
continue
sol.append((random_util.temper(st[i]) ^ mt.genrand())[:])
nsol = sol.num_solutions(32*624)
if nsol == 1:
break
res = sol.solve()
res = LinearFunc.reshape(res, [32]*624)
random.setstate((3, tuple(res + [624]), None))
aaa = [random.randrange(3133731337) for i in range(0x1337)]
assert aaa == a
sha512 = hashlib.sha512()
for _ in range(1000):
rnd = random.getrandbits(32)
sha512.update(str(rnd).encode("ascii"))
key = sha512.digest()
print(bytes(a ^ b for a, b in zip(flag_enc, key)))