Commit 9fd0bd7c by carpoon

don't have a copy of the same class twice

Move ssh.py one level up as the same code is used
parent eb2eeb12
...@@ -32,7 +32,7 @@ from base64 import decodestring ...@@ -32,7 +32,7 @@ from base64 import decodestring
from hashlib import md5 from hashlib import md5
from .ssh import PubKey from agent.ssh import PubKey
from .network import change_ip_freebsd from .network import change_ip_freebsd
from agent.BaseContext import BaseContext from agent.BaseContext import BaseContext
......
...@@ -20,7 +20,7 @@ from glob import glob ...@@ -20,7 +20,7 @@ from glob import glob
from io import StringIO from io import StringIO
from base64 import decodestring from base64 import decodestring
from hashlib import md5 from hashlib import md5
from agent.linux.ssh import PubKey from agent.ssh import PubKey
from agent.linux.network import change_ip_ubuntu, change_ip_rhel from agent.linux.network import change_ip_ubuntu, change_ip_rhel
from agent.BaseContext import BaseContext from agent.BaseContext import BaseContext
from twisted.internet import reactor from twisted.internet import reactor
......
from base64 import decodestring
from struct import unpack
import binascii
import unittest
class InvalidKeyType(Exception):
pass
class InvalidKey(Exception):
pass
class PubKey(object):
key_types = ('ssh-rsa', 'ssh-dsa', 'ssh-ecdsa')
# http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-
# validation-using-a-regular-expression
@classmethod
def validate_key(cls, key_type, key):
try:
data = decodestring(key)
except binascii.Error:
raise InvalidKey()
int_len = 4
str_len = unpack('>I', data[:int_len])[0]
if data[int_len:int_len + str_len] != key_type:
raise InvalidKey()
def __init__(self, key_type, key, comment):
if key_type not in self.key_types:
raise InvalidKeyType()
self.key_type = key_type
PubKey.validate_key(key_type, key)
self.key = key
self.comment = str(comment)
def __hash__(self):
return hash(frozenset(list(self.__dict__.items())))
def __eq__(self, other):
return self.__dict__ == other.__dict__
@classmethod
def from_str(cls, line):
key_type, key, comment = line.split()
return PubKey(key_type, key, comment)
def __unicode__(self):
return ' '.join((self.key_type, self.key, self.comment))
def __repr__(self):
return '<PubKey: %s>' % str(self)
# Unit tests
class SshTestCase(unittest.TestCase):
def setUp(self):
self.p1 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.p2 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.p3 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EC comment')
def test_invalid_key_type(self):
self.assertRaises(InvalidKeyType, PubKey, 'ssh-inv', 'x', 'comment')
def test_valid_key(self):
PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
def test_invalid_key(self):
self.assertRaises(InvalidKey, PubKey, 'ssh-rsa', 'x', 'comment')
def test_invalid_key2(self):
self.assertRaises(InvalidKey, PubKey, 'ssh-rsa',
'AAAAB3MzaC1yc2EA', 'comment')
def test_repr(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(
repr(p), '<PubKey: ssh-rsa AAAAB3NzaC1yc2EA comment>')
def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self):
self.assertEqual(self.p1, self.p2)
self.assertNotEqual(self.p1, self.p3)
def test_hash(self):
s = set()
s.add(self.p1)
s.add(self.p2)
s.add(self.p3)
self.assertEqual(len(s), 2)
if __name__ == '__main__':
unittest.main()
...@@ -18,8 +18,8 @@ class PubKey(object): ...@@ -18,8 +18,8 @@ class PubKey(object):
# http://stackoverflow.com/questions/2494450/ssh-rsa-public-key- # http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-
# validation-using-a-regular-expression # validation-using-a-regular-expression
@classmethod @staticmethod
def validate_key(cls, key_type, key): def validate_key(key_type, key):
try: try:
data = decodestring(key) data = decodestring(key)
except binascii.Error: except binascii.Error:
...@@ -45,8 +45,8 @@ class PubKey(object): ...@@ -45,8 +45,8 @@ class PubKey(object):
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
@classmethod @staticmethod
def from_str(cls, line): def from_str(line):
key_type, key, comment = line.split() key_type, key, comment = line.split()
return PubKey(key_type, key, comment) return PubKey(key_type, key, comment)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment