Commit 3550b54a by Scott Duckworth

make parse_key() return a PublicKey instance

parent 80071eb9
...@@ -29,28 +29,7 @@ ...@@ -29,28 +29,7 @@
from django.db import models from django.db import models
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django_sshkey.util import SSHKeyFormatError, key_parse from django_sshkey.util import SSHKeyFormatError, pubkey_parse
def wrap(text, width, end=None):
n = 0
t = ''
if end is None:
while n < len(text):
m = n + width
t += text[n:m]
if len(text) <= m:
return t
t += '\n'
n = m
else:
while n < len(text):
m = n + width
if len(text) <= m:
return t + text[n:m]
m -= len(end)
t += text[n:m] + end + '\n'
n = m
return t
class UserKey(models.Model): class UserKey(models.Model):
user = models.ForeignKey(User, db_index=True) user = models.ForeignKey(User, db_index=True)
...@@ -75,18 +54,15 @@ class UserKey(models.Model): ...@@ -75,18 +54,15 @@ class UserKey(models.Model):
def clean(self): def clean(self):
try: try:
info = key_parse(self.key) pubkey = pubkey_parse(self.key)
self.fingerprint = info.fingerprint
if info.comment:
self.key = "%s %s %s" % (info.type.decode(), info.b64key.decode(), info.comment)
else:
self.key = "%s %s" % (info.type.decode(), info.b64key.decode())
except SSHKeyFormatError as e: except SSHKeyFormatError as e:
raise ValidationError(str(e)) raise ValidationError(str(e))
self.key = pubkey.format_openssh()
self.fingerprint = pubkey.fingerprint()
if not self.name: if not self.name:
if not info.comment: if not pubkey.comment:
raise ValidationError('Name or key comment required') raise ValidationError('Name or key comment required')
self.name = info.comment self.name = pubkey.comment
def validate_unique(self, exclude=None): def validate_unique(self, exclude=None):
if self.pk is None: if self.pk is None:
...@@ -109,14 +85,8 @@ class UserKey(models.Model): ...@@ -109,14 +85,8 @@ class UserKey(models.Model):
pass pass
def export_openssh(self): def export_openssh(self):
return self.key.encode('utf-8') return self.key
def export_rfc4716(self): def export_rfc4716(self):
info = key_parse(self.key) pubkey = pubkey_parse(self.key)
out = b'---- BEGIN SSH2 PUBLIC KEY ----\n' return pubkey.format_rfc4716()
if info.comment:
comment = 'Comment: "%s"' % info.comment
out += wrap(comment, 72, '\\').encode('ascii') + b'\n'
out += wrap(info.b64key, 72).encode('ascii') + b'\n'
out += b'---- END SSH2 PUBLIC KEY ----'
return out
...@@ -26,11 +26,31 @@ ...@@ -26,11 +26,31 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import namedtuple import base64
import struct
SSHKEY_LOOKUP_URL_DEFAULT = 'http://localhost:8000/sshkey/lookup' SSHKEY_LOOKUP_URL_DEFAULT = 'http://localhost:8000/sshkey/lookup'
KeyInfo = namedtuple('KeyInfo', 'type b64key comment fingerprint') def wrap(text, width, wrap_end=None):
n = 0
t = ''
if wrap_end is None:
while n < len(text):
m = n + width
t += text[n:m]
if len(text) <= m:
return t
t += '\n'
n = m
else:
while n < len(text):
m = n + width
if len(text) <= m:
return t + text[n:m]
m -= len(wrap_end)
t += text[n:m] + wrap_end + '\n'
n = m
return t
class SSHKeyFormatError(Exception): class SSHKeyFormatError(Exception):
def __init__(self, text): def __init__(self, text):
...@@ -39,63 +59,86 @@ class SSHKeyFormatError(Exception): ...@@ -39,63 +59,86 @@ class SSHKeyFormatError(Exception):
def __str__(self): def __str__(self):
return "Unrecognized public key format" return "Unrecognized public key format"
def key_parse(text): class PublicKey(object):
import base64 def __init__(self, b64key, comment=None):
import hashlib self.b64key = b64key
import struct self.comment = comment
lines = text.splitlines() self.keydata = base64.b64decode(b64key.encode('ascii'))
n = struct.unpack('>I', self.keydata[:4])
self.algorithm = self.keydata[4:4+n[0]]
# OpenSSH public key def fingerprint(self):
if len(lines) == 1 and text.startswith(b'ssh-'): import hashlib
fp = hashlib.md5(self.keydata).hexdigest()
return ':'.join(a+b for a,b in zip(fp[::2], fp[1::2]))
def format_openssh(self):
out = self.algorithm + ' ' + self.b64key
if self.comment:
out += ' ' + self.comment
return out
def format_rfc4716(self):
out = '---- BEGIN SSH2 PUBLIC KEY ----\n'
if self.comment:
comment = 'Comment: "%s"' % self.comment
out += wrap(comment, 72, '\\') + '\n'
out += wrap(self.b64key, 72) + '\n'
out += '---- END SSH2 PUBLIC KEY ----'
return out
def pubkey_parse_openssh(text):
fields = text.split(None, 2) fields = text.split(None, 2)
if len(fields) < 2: if len(fields) < 2:
raise SSHKeyFormatError(text) raise SSHKeyFormatError(text)
type = fields[0]
b64key = fields[1]
comment = None
if len(fields) == 3:
comment = fields[2]
try: try:
key = base64.b64decode(b64key) if len(fields) == 2:
key = PublicKey(fields[1])
else:
key = PublicKey(fields[1], fields[2])
except TypeError: except TypeError:
raise SSHKeyFormatError(text) raise SSHKeyFormatError(text)
if fields[0] != key.algorithm:
raise SSHKeyFormatError(text)
return key
# SSH2 public key def pubkey_parse_rfc4716(text):
elif ( lines = text.splitlines()
lines[0] == b'---- BEGIN SSH2 PUBLIC KEY ----' if not (
and lines[-1] == b'---- END SSH2 PUBLIC KEY ----' lines[0] == '---- BEGIN SSH2 PUBLIC KEY ----'
and lines[-1] == '---- END SSH2 PUBLIC KEY ----'
): ):
b64key = b'' raise SSHKeyFormatError(text)
headers = {}
lines = lines[1:-1] lines = lines[1:-1]
b64key = ''
headers = {}
while lines: while lines:
line = lines.pop(0) line = lines.pop(0)
if b':' in line: if ':' in line:
while line[-1] == b'\\': while line[-1] == '\\':
line = line[:-1] + lines.pop(0) line = line[:-1] + lines.pop(0)
k,v = line.split(b':', 1) k,v = line.split(':', 1)
headers[k.lower().decode('ascii')] = v.lstrip().decode('utf-8') headers[k.lower()] = v.lstrip()
else: else:
b64key += line b64key += line
comment = headers.get('comment') comment = headers.get('comment')
if comment and comment[0] in ('"', "'") and comment[0] == comment[-1]: if comment and comment[0] in ('"', "'") and comment[0] == comment[-1]:
comment = comment[1:-1] comment = comment[1:-1]
try: try:
key = base64.b64decode(b64key) return PublicKey(b64key, comment)
except TypeError: except TypeError:
raise SSHKeyFormatError(text) raise SSHKeyFormatError(text)
if len(key) < 4:
raise SSHKeyFormatError(text)
n = struct.unpack('>I', key[:4])
type = key[4:4+n[0]]
# unrecognized format def pubkey_parse(text):
else: lines = text.splitlines()
raise SSHKeyFormatError(text)
if len(lines) == 1:
return pubkey_parse_openssh(text)
fp = hashlib.md5(key).hexdigest() if lines[0] == '---- BEGIN SSH2 PUBLIC KEY ----':
fp = ':'.join(a+b for a,b in zip(fp[::2], fp[1::2])) return pubkey_parse_rfc4716(text)
return KeyInfo(type, b64key, comment, fp)
raise SSHKeyFormatError(text)
def lookup_all(url): def lookup_all(url):
import urllib import urllib
...@@ -146,11 +189,11 @@ def lookup_by_fingerprint_main(): ...@@ -146,11 +189,11 @@ def lookup_by_fingerprint_main():
) )
sys.exit(1) sys.exit(1)
try: try:
info = key_parse(key) pubkey = pubkey_parse(key)
fingerprint = info.fingerprint
except SSHKeyFormatError as e: except SSHKeyFormatError as e:
sys.stderr.write("Error: " + str(e)) sys.stderr.write("Error: " + str(e))
sys.exit(1) sys.exit(1)
fingerprint = pubkey.fingerprint()
url = getenv('SSHKEY_LOOKUP_URL', SSHKEY_LOOKUP_URL_DEFAULT) url = getenv('SSHKEY_LOOKUP_URL', SSHKEY_LOOKUP_URL_DEFAULT)
for key in lookup_by_fingerprint(url, fingerprint): for key in lookup_by_fingerprint(url, fingerprint):
sys.stdout.write(key) sys.stdout.write(key)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment