By rls1004 | May 11, 2016
rabit (Plaid CTF 2016 Quals, 175pt) write-up
블로그 글의 지분율을 보니 막내선원 지분율이 낭낭하네요. ㅇ<-<
rabit.py, util.py, remote server 가 주어졌네요, 이번 CTF에서 암호 문제를 처음 풀어봤는데 암알못인 저와 함께 이 문제를 풀어봅시다!
solve
평범하게 소스 분석부터!
remote server 접속 시 실행되는 코드입니다.
제일 먼저 proof_of_work 함수를 거칩니다.
proof_of_work 함수에서는 숫자와 문자 중 랜덤하게 10글자를 뽑고 클라이언트에게 전송해줍니다.
그 후 클라이언트는 최대 15글자를 입력할 수 있는데, 서버가 전송해준 10글자로 시작하지 않거나 sha1함수로 암호화한 결과가 0xffffff으로 끝나지 않는다면 프로그램이 종료됩니다.
“(랜덤 10글자)+(입력 5글자 이하)”의 sha1 해쉬값 마지막 3 bytes가 0xffffff 여야 한다.
해쉬값을 맞춰주는 문제는 다음과 같이 부루트 포싱을 이용해서 해결할 수 있습니다. python 은 정말 편리해요 :p
def ff_hash(prefix):
letters = string.printable
for c in itertools.product(letters, repeat=5):
response = "".join(c)
if sha1(prefix+response).digest()[-3:] != "\xff"*3:
continue
return prefix+response
해쉬값을 맞춰주면 N과 enc_flag 값이 출력됩니다.
util.py 의 genKey 함수에 의하면 N은 소수인 p와 q의 곱으로 만들어진 수입니다.
enc_flag는 FLAG 문자열 뒤에 패딩 문자(padchar)를 채우고 long 타입의 데이터로 변환한 뒤 몇을 더한 값을 encrypt한 숫자입니다.
원본의 FLAG 값을 알아내려면 encrypt 과정을 거꾸로 하면 될텐데,
encrypt 함수는 첫 번째 인자를 제곱하고 두 번째 인자로 나눈 나머지 값입니다.
그 후엔 alarm 시그널이 발생할 때까지 무한루프를 돌며 사용자가 새로 입력한 값을 decrypt 하여 2로 나눈 나머지 값(lsb)을 보내줍니다.
decrypt 함수를 더 살펴봅시다.
encrypt 함수로 padded를 암호화한 값(enc_flag)을 decrypt 함수에 넣으면 다시 padded가 나옵니다. decrypt 함수는 복잡하게 생겼지만 이 관계에 의하면 제곱근을 구해주는 기능을 하는 것 같네요~
그리고 assert로 단정된 “decrypt(2, p, q) != None”을 통해 “GCD(w, p*q) == 1” 임을 알 수 있고 따라서 p와 q는 2가 아닌 소수입니다.
그 외의 if문은 legendreSymbol을 검색하여 그 기능을 알 수 있었습니다. (친절한 함수 이름 감사합니다) legendreSymbol 함수는 “르장드르 기호”라고 하는데 어떤 수가 제곱잉여인지 아닌지를 나타냅니다.
제곱잉여에 대한 예시를 들어보면, 1² = 1, 2² = 4, 3² = 9, 4² = 16, … 일 때 각각의 수를 7로 나눈 나머지(mod 7)는 1, 4, 2, 4, … 입니다. 이때, 1, 2, 4는 7의 제곱잉여입니다.
decrypt 함수에서의 legendreSymbol은 사용자가 입력한 숫자가 p와 q에 대한 제곱잉여인지 판단합니다. 이게 어떤 수를 의미하는게 있는건지 저는 잘 모르겠네요 :(
아무튼! decrypt 함수는 아래와 같은 관계가 성립됩니다.
M = 평문, C = 암호문
encrypt(M, N) = C
decrypt(C, p, q) = M mod N ( ∵ p*q = N )
decrypt(4, p, q) = 2 mod N ( ∵ decrypt(…) < N )
decrypt(4C, p, q) = 2M mod N
무한루프를 돌며 값을 입력 했을 때 우리가 알 수 있는 건 2로 나눈 나머지인데, decrypt 한 결과가 홀수 일 때만 1을 출력합니다.
우리는 N이 홀수라는 것을 알고 있는데 어떤 경우에 decrypt 결과가 홀수가 될까요?
M = 홀수, N = 홀수, M < 2N
M mod N = 홀수 ( M < N )
M mod N = 짝수 ( M > N )
M = 짝수, N = 홀수, M < 2N
M mod N = 홀수 ( M > N )
M mod N = 짝수 ( M < N )
몇 가지 숫자를 대입해보면 위와 같은 규칙을 금방 찾아낼 수 있습니다.
여기에 decrypt 함수를 적용해봅시다.
M = 평문(FLAG), C = 암호문(enc_flag), N = 홀수
decrypt(4C, p, q) = 2M mod N = 홀수 ( 2M > N, M > N/2 )
decrypt(4C, p, q) = 2M mod N = 짝수 ( 2M < N, M < N/2 )
decrypt(9C, p, q) = 3M mod N = 홀수 ( ??? )
decrypt(16C, p, q) = 4M mod N = 홀수 ( 4M > N, M > N/4 )
decrypt(16C, p, q) = 4M mod N = 짝수 ( 4M < N, M < N/4 )
이럴 수가, lsb 값을 가지고 M의 범위를 알 수 있습니다!
2M과 4M의 경우는 짝수임을 분명히 알 수 있어서 짝수를 홀수로 나눈 나머지의 규칙을 적용할 수 있지만, 3M은 홀수인지 짝수인지 분명하지가 않아 M과 N의 관계를 정확히 알 수 없습니다.
짝수의 제곱 수를 입력하여 M의 범위를 줄여볼까요?
(alarm 시그널이 발생하기 때문에 세밀한 범위가 나올 때까지 확인할 순 없었습니다. M에는 padchar가 붙어있으니 start < M < end 에서 start와 end의 앞의 몇 자리가 같을 때까지만 반복하게 했습니다.)
Exploit
from Crypto.Util.number import bytes_to_long, long_to_bytes
from hashlib import sha1
import itertools
import string
from socket import *
from struct import *
p = lambda x : pack("<L", x)
up = lambda x : unpack("<L", x)
host = 'rabit.pwning.xxx'
port = 7763
sock = socket(AF_INET, SOCK_STREAM, 0)
sock.connect((host, port))
def ff_hash(prefix):
letters = string.printable
for c in itertools.product(letters, repeat=5):
response = "".join(c)
if sha1(prefix+response).digest()[-3:] != "\xff"*3:
continue
return prefix+response
def until(s, string):
data = ''
while string not in data:
data += s.recv(1)
return data
until(sock, "with ")
prefix = until(sock, ",")
prefix = prefix[:-1]
print "[*] prefixx : " + prefix
sendmsg = ff_hash(prefix)
print sendmsg
sock.send(sendmsg+"\n")
N = 81546073902331759271984999004451939555402085006705656828495536906802924215055062358675944026785619015267809774867163668490714884157533291262435378747443005227619394842923633601610550982321457446416213545088054898767148483676379966942027388615616321652290989027944696127478611206798587697949222663092494873481
enc_flag = 16155172062598073107968676378352115117161436172814227581212799030353856989153650114500204987192715640325805773228721292633844470727274927681444727510153616642152298025005171599963912929571282929138074246451372957668797897908285264033088572552509959195673435645475880129067211859038705979011490574216118690919
start = 0
end = N
i = 1
while str(end)[:100] != str(start)[:100]:
i *= 2
M = i**2
until(sock, "Give a ciphertext: ")
sock.send(str(M*enc_flag)+"\n")
until(sock, "lsb is ")
lsb = int(until(sock, "\n")[:-1])
if lsb == 1:
start = (start+end)/2
else:
end = (start+end)/2
print "[*] FLAG for you :-)"
print repr(long_to_bytes(start))