Commit 257edc73 by Bach Dániel

add ssh key support

parent c9e24295
...@@ -3,3 +3,5 @@ ...@@ -3,3 +3,5 @@
.ropeproject .ropeproject
*.pyc *.pyc
version.txt version.txt
*,cover
.coverage
...@@ -12,6 +12,8 @@ import fileinput ...@@ -12,6 +12,8 @@ import fileinput
import platform import platform
import sys import sys
import tarfile import tarfile
from os.path import expanduser, join
from os import mkdir
from glob import glob from glob import glob
from StringIO import StringIO from StringIO import StringIO
from base64 import decodestring from base64 import decodestring
...@@ -20,8 +22,14 @@ from datetime import datetime ...@@ -20,8 +22,14 @@ from datetime import datetime
from utils import SerialLineReceiverBase from utils import SerialLineReceiverBase
from ssh import PubKey
logger = logging.getLogger() logger = logging.getLogger()
SSH_DIR = expanduser('~cloud/.ssh')
AUTHORIZED_KEYS = join(SSH_DIR, 'authorized_keys')
fstab_template = ('sshfs#%(username)s@%(host)s:home /home/cloud/sshfs ' fstab_template = ('sshfs#%(username)s@%(host)s:home /home/cloud/sshfs '
'fuse defaults,idmap=user,reconnect,_netdev,uid=1000,' 'fuse defaults,idmap=user,reconnect,_netdev,uid=1000,'
'gid=1000,allow_other,StrictHostKeyChecking=no,' 'gid=1000,allow_other,StrictHostKeyChecking=no,'
...@@ -152,6 +160,59 @@ class Context(object): ...@@ -152,6 +160,59 @@ class Context(object):
f.write(data) f.write(data)
@staticmethod @staticmethod
def get_keys():
retval = []
try:
with open(AUTHORIZED_KEYS, 'r') as f:
for line in f.readlines():
try:
retval.append(PubKey.from_str(line))
except:
logger.exception(u'Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
@staticmethod
def add_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
if p not in new_keys:
new_keys.append(p)
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def del_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
try:
new_keys.remove(p)
except ValueError:
pass
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def cleanup(): def cleanup():
if system == 'Linux': if system == 'Linux':
filelist = ([ filelist = ([
......
from base64 import decodestring
from struct import unpack
import binascii
class InvalidKeyType(Exception):
pass
class InvalidKey(Exception):
pass
class PubKey(object):
key_types = ('ssh-rsa', 'ssh-dsa', 'ssh-ecdsa')
# http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-
# validation-using-a-regular-expression
@classmethod
def validate_key(cls, key_type, key):
try:
data = decodestring(key)
except binascii.Error:
raise InvalidKey()
int_len = 4
str_len = unpack('>I', data[:int_len])[0]
if data[int_len:int_len + str_len] != key_type:
raise InvalidKey()
def __init__(self, key_type, key, comment):
if key_type not in self.key_types:
raise InvalidKeyType()
self.key_type = key_type
PubKey.validate_key(key_type, key)
self.key = key
self.comment = unicode(comment)
def __hash__(self):
return hash(frozenset(self.__dict__.items()))
def __eq__(self, other):
return self.__dict__ == other.__dict__
@classmethod
def from_str(cls, line):
key_type, key, comment = line.split()
return PubKey(key_type, key, comment)
def __unicode__(self):
return u' '.join((self.key_type, self.key, self.comment))
def __repr__(self):
return u'<PubKey: %s>' % unicode(self)
import unittest
class SshTestCase(unittest.TestCase):
def setUp(self):
self.p1 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.p2 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.p3 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EC comment')
def test_invalid_key_type(self):
self.assertRaises(InvalidKeyType, PubKey, 'ssh-inv', 'x', 'comment')
def test_valid_key(self):
PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
def test_invalid_key(self):
self.assertRaises(InvalidKey, PubKey, 'ssh-rsa', 'x', 'comment')
def test_invalid_key2(self):
self.assertRaises(InvalidKey, PubKey, 'ssh-rsa',
'AAAAB3MzaC1yc2EA', 'comment')
def test_repr(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(
repr(p), '<PubKey: ssh-rsa AAAAB3NzaC1yc2EA comment>')
def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self):
self.assertEqual(self.p1, self.p2)
self.assertNotEqual(self.p1, self.p3)
def test_hash(self):
s = set()
s.add(self.p1)
s.add(self.p2)
s.add(self.p3)
self.assertEqual(len(s), 2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment