문제

Haven't you ever thought that GCM mode is overcomplicated and there must be a simpler way to achieve Authenticated Encryption? Here it is!

Server: aes-128-tsb.hackable.software 1337

server.py


분석

server.py의 내용은 다음과 같다.

#!/usr/bin/env python2
import SocketServer
import socket
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from struct import pack, unpack

from secret import AES_KEY, FLAG

class CryptoError(Exception):
    pass

def split_by(data, step):
    return [data[i : i+step] for i in xrange(0, len(data), step)]

def xor(a, b):
    assert len(a) == len(b)
    return ''.join([chr(ord(ai)^ord(bi)) for ai, bi in zip(a,b)])

def pad(msg):
    byte = 16 - len(msg) % 16
    return msg + chr(byte) * byte

def unpad(msg):
    if not msg:
        return ''
    return msg[:-ord(msg[-1])]

def tsb_encrypt(aes, msg):
    msg = pad(msg)
    iv = get_random_bytes(16)
    prev_pt = iv
    prev_ct = iv
    ct = ''
    for block in split_by(msg, 16) + [iv]:
        ct_block = xor(block, prev_pt)
        ct_block = aes.encrypt(ct_block)
        ct_block = xor(ct_block, prev_ct)
        ct += ct_block
        prev_pt = block
        prev_ct = ct_block
    return iv + ct

def tsb_decrypt(aes, msg):
    iv, msg = msg[:16], msg[16:]
    prev_pt = iv
    prev_ct = iv
    pt = ''
    for block in split_by(msg, 16):
        pt_block = xor(block, prev_ct)
        pt_block = aes.decrypt(pt_block)
        pt_block = xor(pt_block, prev_pt)
        pt += pt_block
        prev_pt = pt_block
        prev_ct = block
    pt, mac = pt[:-16], pt[-16:]
    if mac != iv:
        raise CryptoError()
    return unpad(pt)

def send_binary(s, msg):
    s.sendall(pack('<I', len(msg)))
    s.sendall(msg)

def send_enc(s, aes, msg):
    send_binary(s, tsb_encrypt(aes, msg))

def recv_exact(s, length):
    buf = ''
    while length %gt; 0:
        data = s.recv(length)
        if data == '':
            raise EOFError()
        buf += data
        length -= len(data)
    return buf

def recv_binary(s):
    size = recv_exact(s, 4)
    size = unpack('<I', size)[0]
    return recv_exact(s, size)

def recv_enc(s, aes):
    data = recv_binary(s)
    return tsb_decrypt(aes, data)

def main(s):
    aes = AES.new(AES_KEY, AES.MODE_ECB)
    try:
        while True:
            a = recv_binary(s)
            b = recv_enc(s, aes)
            if a == b:
                if a == 'gimme_flag':
                    send_enc(s, aes, FLAG)
                else:
                    # Invalid request, send some random garbage instead of the
                    # flag :)
                    send_enc(s, aes, get_random_bytes(len(FLAG)))
            else:
                send_binary(s, 'Looks like you don\'t know the secret key? Too bad.')
    except (CryptoError, EOFError):
        pass

class TaskHandler(SocketServer.BaseRequestHandler):
    def handle(self):
        main(self.request)

if __name__ == '__main__':
    SocketServer.ThreadingTCPServer.allow_reuse_address = True
    server = SocketServer.ThreadingTCPServer(('0.0.0.0', 1337), TaskHandler)
    server.serve_forever()

 

이 문제의 서버는 입력을 두번받고 출력을 한번해주는 방식으로 작동한다.

다만 모든 입출력은 정해진 형식으로 주고받는데

우선 최초의 4바이트는 전달될 메세지(평문 or 암호문)의 길이를 int packing한 형태이고

그 뒤로 해당 길이만큼의 메세지가 전달된다. (send_binary, recv_binary 참고)

그래서 server.py가 제공하는 함수를 모두 포함하고 있는 client.py를 만들고

보내고 싶은 메세지 두개를 전달인자로 넣어주면 서버의 출력을 반환해주는 run 함수를 만들었다.

def run(msg1, msg2, choice='b'):
	if choice=='':
		choice = 'b'
	send_binary(r, msg1)
	send_binary(r, msg2)
	if choice == 'b':
		return recv_binary(r)
	elif choice == 'd':
		return recv_enc(r)

 

이 문제에서 사용하는 블럭암호 방식은 다음과 같다.

앞으로 이 방식을 TSB 암호화 방식이라 하자.

 

서버의 메인함수에서는 두개의 입력을 받는다. (msg1, msg2)

그리고 변수 a 에는 msg1을 그대로 넣고

변수 b 에는 msg2를 TSB로 복호화한 값을 넣는다.

그리고 a와 b가 같은지 비교해서 다르면 우리를 놀리는 문자열을 출력한다.

a와 b가 같다면 flag의 길이만큼 랜덤 문자열을 생성한 뒤 TSB로 암호화한 값을 출력한다.

만약 a와 b가 같은데 a가 'gimme_flag'와 같다면 flag를 TSB로 암호화한 값을 출력한다.

 

이것저것 실험해보는 중 재미있는 점을 발견했다.

c1 = get_random_bytes(16)

c2 = get_random_bytes(16)

run('', c1 + c2 + c1)

위와 같이 실행하면 우리를 놀리는 문자열이 아닌 읽을 수 없는 문자열이 반환된다는 것이다!

+) 두번째 입력이 ABA형태인 이유는 mac검사를 통과하기 위함이다.

+) valid한 암호문은 ABCDE.....A의 형태를 가진다.

 

잘 생각해보면 아마 서버의 a변수와 b변수가 같아져서 랜덤 문자열의 암호문이 반환된 것 같다!

반환되는 문자열의 길이도 조사해봤는데 항상 96이다.

참고로, 우리를 놀리는 문자열의 길이는 50이다.

 

그리고 이 코드를 반복적으로 호출하다보면 가끔씩은 길이가 50인 문자열도 반환된다.

왜 그런지 알아보기 위해 복호화과정을 자세히 살펴보니 unpad 함수가 문제였다.

이 암호화방식의 패딩방식은 length 패딩이다. block size에서 부족한 길이 값을 패딩바이트로 넣는 것이다.

(예: msg가 12바이트일때 패딩결과 : msg + '\x04' * 4

근데 unpad를 하는 방식이 깔끔하지 못하다.

메세지의 마지막 바이트만 보고 그 값을 토대로 패딩길이를 판단하고 그 길이만큼의 문자열 꼬리를 잘라낸다.

이 때, 마지막 바이트가 만약 0x10 이상의 값이라면?

unpad 결과 반환되는 문자열은 빈 문자열이다 ('')

아하! 그래서 서버의 a와 b가 일치했던 것이다!

그럼 가끔씩 길이 50인 문자열이 반환된 것은?

복호화 결과 마지막 바이트가 0x10보다 작았던 것이다.

그럼 unpad 결과로 빈 문자열이 반환되지 않았을 것이고 a!=b가 참이 된다.

 


EXPLOIT

자 이제 이것을 이용해 볼 방법을 찾아보자. 

내가 생각해낸 문제를 푸는 기본 개념은 다음과 같다.

a로 '\x01'을 전달하고

unpad되기 전의 b의 값이 '\x01' + '?' * 14 + '\x0f'라면?

a == b가 만족된다.

*) 아래에서는 설명의 편의를 위해 (unpad되기 전의 b)를 b'라 하겠다.

 

그렇다면 c1으로 '\x01'을 넣고 c2는 계속 랜덤하게 줘서 저 조건을 맞출 것인가?

확률은 1/(256*256) 이다.

가능이야 하겠지만 꽤 오래 걸릴 것 같다.

시간을 획기적으로 단축시킬 방법이 한가지 있다.

run('', c1+c2+c1)의 결과로 길이 50인 문자열이 반환될 때를 찾는 것!

그 때는 b'의 마지막 바이트가 0xf 이하라는 뜻이다!

그 때의 c1, c2가 중요하다.

TSB복호화 과정을 잘 생각해보면

b' = DEC(c1 ^ c2) ^ c1

*) ENC()와 DEC()는 AES 한블럭의 암,복호화를 의미한다.

 

이때, c1 ^ c2 의 값은 유지한 채로 c1의 특정바이트만 조정하면 b'의 해당바이트를 바꿀 수 있다!

예를 들어 랜덤으로 찾은 c1의 마지막 바이트가 x이고 그 때의 b'의 마지막 바이트가 '\x00'인게 밝혀지면

c1의 마지막 바이트를 x' 으로 바꿨을때 b'의 마지막 바이트는 ('\x00' ^ x) ^ x' 이 될 것이다.

이 점을 이용하면 x'를 잘 조정해서 b'의 마지막 바이트를 우리가 원하는 대로 세팅해줄 수 있다.

 

그럼 b'의 마지막 바이트가 '\x00'일 때를 어떻게 찾는가?

사실 위에서 b'의 마지막 바이트가 '\x0f'이하로 만드는 방법은 설명을 했다.

이 때의 x'이 중요하다. 만약 이 때 x'이 '\xc4' 였다면

x'을 '\xc0' 부터 '\xcf' 까지 변화시켜보자.

b'의 마지막 바이트가 '\x00'이 될때는 서버 출력 문자열의 길이는 50이 아니라 96일 것이다.

이 또한 unpad 과정을 잘 생각해보면 알 수 있다.

 

a = '\x01'로 고정하고

b' = '\x01' + '?' * 14 + '\x0f'이 되는 c1을 찾는다.

c1 ^ c2를 유지한다면 c1의 첫바이트만 '\x00' 에서 '\xff'까지 쭉 대입해줬을 때 그 중 한 케이스가 우리가 원하는 b'을 만들어 줄 것이다.

우리가 원하는 b'을 찾았을 때 서버의 출력의 길이는 96이고 아닐 때는 50이라는 점을 이용했다.

 

그 다음 과정은 이렇다.

이제

a = '\x01\x01'로 고정하고

b = '\x01\x01' + '?' * 13 + '\x0e'이 되는 c1을 찾자.

이 때 c1의 첫바이트는 전 단계에서 찾은 값을 이용하고 두번째 바이트만 바꿔가면서 테스트해주면 된다.

 

이런식으로 반복하다보면

a = '\x01' * 16일때

b = '\x01' * 15 + '?' * 0 + '\x01'이 되는 c1까지 찾아낼 수 있을 것이다.

이 말은 ENC(c1 ^ c2) = c1 ^ b 라는 것이다.

평문-암호문 쌍 한개를 찾아낸 것이다!!

이 때

xx = c1 ^ c2

yy = c1 ^ b 라 하자.

이제

new_c1 = yy ^ pad('gimme_flag')

new_c2 = xx ^ yy ^ pad('gimme_flag') 이라 하고

run('gimme_flag', new_c1 + new_c2 + new_c1) 을 실행하면

TSB로 암호화된 flag를 얻을 수 있다!

 

이제 이를 복호화해주는 작업만 남았다.

한 블럭씩 복호화해주는 코드는 다음과 같다.

이때 c_1과 c_2에는 TSB 암호화 된 flag을 블럭단위로 잘라서 차례로 넣어준다.

즉, 처음에는 c_1 = c[0]; c_2 = c[1]이고

두번째에는 c_1 = c[1]; c_2 = c[2]

...

와 같은 방식이다.

코드에 대한 상세한 설명은 주석으로 달아놓았다.

    # c는 TSB 암호화 된 flag를 블럭단위로 나누어서 저장해놓은 list
    c_1 = c[0] # c[0] : iv
    c_2 = c[1]
    x = xor(c_1, c_2)

    # b'의 마지막 바이트를 0x0f 이하로 만드는 c_1 값 찾기
    i = 0x00
    while True:
        c_1 = get_random_bytes(16)
        c_2 = xor(x, c_1)
        # print hex(i)
        ret = run('', c_1 + c_2 + c_1)
        if len(ret) == 50:
            # print c_1.encode('hex')
            break
        i += 1
    c_1_found = c_1

    # b'의 마지막 바이트를 0x00 으로 만드는 c_1 값 찾기
    b_last = c_1[-1]
    b_last &= 0xf0
    while True:
        c_1 = c_1_found[:-1] + chr(b_last)
        c_2 = xor(x, c_1)
        # print hex(b_last)
        ret = run('', c_1 + c_2 + c_1)
        if len(ret) == 96:
            # print c_1.encode('hex')
            break
        b_last += 1
    c_1_found2 = c_1

    # b' == '\x01' * 16인 c_1 값 찾고 평문 한 블럭 구하기
    # p = 'DrgnS{Thank_god_no_one_deployed_this_on_producti'
    # p_prev = 'this_on_producti'
    p_prev = c[0]   # 처음에는 iv 값이다.
    c_first = ''
    i = 1
    while True:
        b_test = 0x00
        c_last = c_1_found2[-1] ^ (0x10 - i)    # c_1_found2[-1] == b_last
        if i == 16:
            break
        while True:
            c_1 = c_first
            c_1 += chr(b_test) + c_1_found[i:-1] + chr(c_last)
            c_2 = xor(x, c_1)
            #print hex(b_test)
            ret = run('\x01' * i, c_1 + c_2 + c_1)
            if len(ret) == 96:
                    #print c_1.encode('hex')
                    c_first += chr(b_test)
                    i += 1
                    break
            b_test += 1

    d_x = xor('\x01'*16, c_1)
    r = xor(d_x, p_prev)
    print 'flag : ', p + r

 

 


FL4G

처음 한 블럭만 실행해본 모습, flag가 나오기 시작함을 알 수 있다

 

마지막 블럭에 대해 실행한 모습, 전체 flag를 구해낼 수 있었다.

프로그램을 다 만들고 4개의 블럭에 대해 모두 수행해주는데 걸린시간은 약 40분? 사실 잘 기억이 안난다..

+ Recent posts