# Loosely based on the public domain code at
# http://ed25519.cr.yp.to/software.html
"""Pure-Python Ed25519 signing and verification.

Written for clarity and testing. It is not constant-time: eddsa_mul branches
on the bits of the (possibly secret) scalar.
"""
import hashlib


def H(m):
    """Return the SHA-512 digest of the byte string *m*."""
    return hashlib.sha512(m).digest()


# Base field Z_p
p = 2**255 - 19


def modp_inv(x):
    """Return the inverse of *x* modulo p, computed as x^(p-2) mod p."""
    return pow(x, p - 2, p)


# Curve constant
d = -121665 * modp_inv(121666) % p

# Group order
q = 2**252 + 27742317777372353535851937790883648493


def H_modq(s):
    """Hash *s* with SHA-512 and reduce the little-endian digest modulo q."""
    return int.from_bytes(H(s), "little") % q


# Points are represented as tuples (X, Y, Z, T) of extended coordinates,
# with x = X/Z, y = Y/Z, x*y = T/Z


def eddsa_add(P, Q):
    """Return the sum of points P and Q in extended coordinates.

    The local names A..H follow the point-addition formulas of RFC 8032,
    section 5.1.4. Inside this function they shadow the module-level ``B``
    (base point) and ``H`` (hash). The returned coordinates are not reduced
    modulo p; every consumer in this module reduces them.
    """
    A = (P[1] - P[0]) * (Q[1] - Q[0]) % p
    B = (P[1] + P[0]) * (Q[1] + Q[0]) % p
    C = 2 * P[3] * Q[3] * d % p
    D = 2 * P[2] * Q[2] % p
    E = B - A
    F = D - C
    G = D + C
    H = B + A
    return (E * F, G * H, F * G, E * H)


def eddsa_mul(s, P):
    """Return the scalar multiple s * P, computed by double-and-add.

    A non-positive *s* yields the neutral element. The loop branches on the
    bits of *s*, so its running time depends on the scalar.
    """
    Q = (0, 1, 1, 0)  # Neutral element
    while s > 0:
        if s & 1:
            Q = eddsa_add(Q, P)
        P = eddsa_add(P, P)
        s >>= 1
    return Q


# Square root of -1. Since p = 5 (mod 8), 2 is a quadratic non-residue, so
# 2^((p-1)/2) = -1 and 2^((p-1)/4) squares to -1.
modp_sqrt_m1 = pow(2, (p - 1) // 4, p)


def eddsa_root(y, sign):
    """Recover the x coordinate for *y* whose low bit equals *sign*.

    Returns None on failure:
      * y is not a canonical field element (y >= p);
      * no x exists for this y;
      * x = 0 is the only solution but sign = 1 (a non-canonical encoding).
    """
    if y >= p:
        return None
    # Reduce mod p so the zero test below is a test in the field. Without
    # the reduction, e.g. y = p - 1 gives a nonzero multiple of p and the
    # "x = 0 with sign set" case slips through.
    x2 = (y * y - 1) * modp_inv(d * y * y + 1) % p
    if x2 == 0:
        if sign:
            return None
        return 0

    # Compute square root of x2. Because p = 5 (mod 8), the candidate
    # x2^((p+3)/8) squares to either x2 or -x2. In the latter case,
    # multiplying by sqrt(-1) fixes it.
    x = pow(x2, (p + 3) // 8, p)
    if (x * x - x2) % p != 0:
        x = x * modp_sqrt_m1 % p
    if (x * x - x2) % p != 0:
        return None

    if (x & 1) != sign:
        x = p - x
    return x


def eddsa_base():
    """Return the standard base point (y = 4/5, x with low bit 0)."""
    y = 4 * modp_inv(5) % p
    x = eddsa_root(y, 0)
    return (x, y, 1, x * y % p)


B = eddsa_base()


def eddsa_decompress(s):
    """Decode a 32-byte point encoding into extended coordinates.

    The encoding is y in little-endian, with the low bit of x stored in
    bit 255. Returns None if the bytes do not encode a valid curve point.

    Raises:
        ValueError: if *s* is not exactly 32 bytes long.
    """
    if len(s) != 32:
        raise ValueError("Invalid input length for decompression")
    y = int.from_bytes(s, "little")
    sign = y >> 255
    y &= (1 << 255) - 1
    x = eddsa_root(y, sign)
    if x is None:
        return None
    return (x, y, 1, x * y % p)


def eddsa_compress(P):
    """Encode point P as 32 bytes: y little-endian, low bit of x in bit 255."""
    zinv = modp_inv(P[2])
    x = P[0] * zinv % p
    y = P[1] * zinv % p
    return int.to_bytes(y | ((x & 1) << 255), 32, "little")


def eddsa_subkeys(sk):
    """Expand a 32-byte secret key into (scalar a, 32-byte nonce prefix).

    Both values come from SHA-512(sk). The scalar is the low half of the
    digest, with bits 0-2 cleared, bit 255 cleared, and bit 254 set. The
    prefix is the high half.

    Raises:
        ValueError: if *sk* is not exactly 32 bytes long.
    """
    if len(sk) != 32:
        raise ValueError("Bad size of private key")
    h = H(sk)
    a = int.from_bytes(h[:32], "little")
    a &= (1 << 254) - 8
    a |= 1 << 254
    return (a, h[32:])


def eddsa_public(sk):
    """Return the 32-byte encoding of the public key a * B for *sk*."""
    a, _ = eddsa_subkeys(sk)
    return eddsa_compress(eddsa_mul(a, B))


def eddsa_sign(sk, msg):
    """Return the 64-byte signature R || S of *msg* under secret key *sk*."""
    a, prefix = eddsa_subkeys(sk)
    A = eddsa_compress(eddsa_mul(a, B))
    r = H_modq(prefix + msg)
    R = eddsa_mul(r, B)
    Rs = eddsa_compress(R)
    h = H_modq(Rs + A + msg)
    s = (r + h * a) % q
    return Rs + int.to_bytes(s, 32, "little")


def eddsa_equal(P, Q):
    """Return True if P and Q represent the same affine point."""
    # x1 / z1 == x2 / z2 <==> x1 * z2 == x2 * z1
    if (P[0] * Q[2] - Q[0] * P[2]) % p != 0:
        return False
    if (P[1] * Q[2] - Q[1] * P[2]) % p != 0:
        return False
    return True


def eddsa_verify(pk, msg, signature):
    """Return True if *signature* is a valid signature of *msg* under *pk*.

    Checks s * B == R + h * A, where h = SHA-512(R || pk || msg) mod q.

    Returns False when:
      * pk or R does not decode to a curve point;
      * s is not reduced modulo q. Accepting s + q would verify exactly like
        s and make signatures malleable.

    Raises:
        ValueError: if *pk* is not 32 bytes or *signature* is not 64 bytes.
    """
    if len(pk) != 32:
        raise ValueError("Bad public-key length")
    if len(signature) != 64:
        raise ValueError("Bad signature length")
    A = eddsa_decompress(pk)
    if A is None:
        return False
    Rs = signature[:32]
    R = eddsa_decompress(Rs)
    if R is None:
        return False
    s = int.from_bytes(signature[32:], "little")
    if s >= q:
        return False
    h = H_modq(Rs + pk + msg)
    sB = eddsa_mul(s, B)
    hA = eddsa_mul(h, A)
    return eddsa_equal(sB, eddsa_add(R, hA))


if __name__ == "__main__":
    import sys
    import binascii

    def eddsa_valid(P):
        """Return True if P satisfies the curve equation.

        Also asserts that the T coordinate is consistent with x * y.
        """
        zinv = modp_inv(P[2])
        x = P[0] * zinv % p
        y = P[1] * zinv % p
        assert (x * y - P[3] * zinv) % p == 0
        return (-x * x + y * y - 1 - d * x * x * y * y) % p == 0

    assert eddsa_valid(B)
    Z = (0, 1, 1, 0)
    assert eddsa_valid(Z)
    assert eddsa_equal(Z, eddsa_add(Z, Z))
    assert eddsa_equal(B, eddsa_add(Z, B))
    assert eddsa_equal(Z, eddsa_mul(0, B))
    assert eddsa_equal(B, eddsa_mul(1, B))
    assert eddsa_equal(eddsa_add(B, B), eddsa_mul(2, B))
    for i in range(0, 100):
        assert eddsa_valid(eddsa_mul(i, B))
    assert eddsa_equal(Z, eddsa_mul(q, B))

    def munge_string(s, pos, change):
        """Return a copy of bytes *s* with byte *pos* XORed with *change*."""
        return s[:pos] + int.to_bytes(s[pos] ^ change, 1, "little") + s[pos + 1:]

    # Read a file in the format of
    # http://ed25519.cr.yp.to/python/sign.input
    # Each line holds colon-separated hex fields. Only the leading 32 bytes
    # of field 0 and the leading 64 bytes of field 3 are used.
    for lineno, line in enumerate(sys.stdin, start=1):
        print(lineno)
        fields = line.split(":")
        sk = binascii.unhexlify(fields[0])[:32]
        pk = binascii.unhexlify(fields[1])
        msg = binascii.unhexlify(fields[2])
        signature = binascii.unhexlify(fields[3])[:64]

        assert pk == eddsa_public(sk)
        assert signature == eddsa_sign(sk, msg)
        assert eddsa_verify(pk, msg, signature)
        if len(msg) == 0:
            bad_msg = b"x"
        else:
            bad_msg = munge_string(msg, len(msg) // 3, 4)
        assert not eddsa_verify(pk, bad_msg, signature)
        bad_signature = munge_string(signature, 20, 8)
        assert not eddsa_verify(pk, msg, bad_signature)
        bad_signature = munge_string(signature, 40, 16)
        assert not eddsa_verify(pk, msg, bad_signature)