Commit 382cea4d by carpoon

migration first step

py2to3
Move all sources to agent folder to help imports
rename utils to SerialLineReceiverBase as it only contains this class
fix string and binary storage differences in python3
update requirements, add windows specific requirements
move platform specific logic to the platform specific classes
parent 6bdaacab
...@@ -5,16 +5,13 @@ import servicemanager ...@@ -5,16 +5,13 @@ import servicemanager
import socket import socket
import sys import sys
import win32event import win32event
import win32service
import win32serviceutil import win32serviceutil
import win32service
from agent import main as agent_main, reactor from agent.agent import init_serial, reactor
logger = logging.getLogger() logger = logging.getLogger()
fh = NTEventLogHandler( fh = NTEventLogHandler("CIRCLE Agent", dllname=os.path.dirname(__file__))
"CIRCLE Agent", dllname=os.path.dirname(__file__)) formatter = logging.Formatter("%(asctime)s - %(name)s [%(levelname)s] %(message)s")
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO') level = os.environ.get('LOGLEVEL', 'INFO')
...@@ -42,7 +39,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -42,7 +39,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED, servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, '')) (self._svc_name_, ''))
logger.info("%s starting", __file__) logger.info("%s starting", __file__)
agent_main() init_serial()
def main(): def main():
...@@ -68,4 +65,4 @@ if __name__ == '__main__': ...@@ -68,4 +65,4 @@ if __name__ == '__main__':
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
raise raise
except Exception: except Exception:
logger.exception("Exception:") logger.exception("Exception: %s" % str(Exception))
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from os import environ, chdir from os import environ, chdir
import platform import platform
import subprocess import subprocess
import sys import sys
try: # noqa
system = platform.system() # noqa chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
if system == "Linux" or system == "FreeBSD": # noqa except Exception: # noqa
try: # noqa pass # hope it works # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
...@@ -23,150 +18,22 @@ from twisted.internet.task import LoopingCall ...@@ -23,150 +18,22 @@ from twisted.internet.task import LoopingCall
import uptime import uptime
import logging import logging
from inspect import getargspec, isfunction from inspect import getfullargspec, isfunction
from utils import SerialLineReceiverBase
from agent.SerialLineReceiverBase import SerialLineReceiverBase
from agent.agent import init_serial, reactor
# Note: Import everything because later we need to use the BaseContext # Note: Import everything because later we need to use the BaseContext
# (relative import error. # (relative import error.
from context import BaseContext, get_context, get_serial # noqa from agent.context import BaseContext, get_context # noqa
Context = get_context() Context = get_context()
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.transport.write('\r\n')
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(),
'system': system})
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionLost(self, reason):
reactor.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def send_status(self): if __name__ == '__main__':
import psutil logging.basicConfig()
disk_usage = dict((disk.device.replace('/', '_'), logger = logging.getLogger()
psutil.disk_usage(disk.mountpoint).percent) level = environ.get('LOGLEVEL', 'INFO')
for disk in psutil.disk_partitions())
args = {"cpu": dict(psutil.cpu_times()._asdict()),
"ram": dict(psutil.virtual_memory()._asdict()),
"swap": dict(psutil.swap_memory()._asdict()),
"uptime": {"seconds": uptime.uptime()},
"disk": disk_usage,
"user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args)
logger.debug("send_status finished")
def _check_args(self, func, args):
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
(self._pretty_fun(func), type(args).__name__))
# check for unexpected keyword arguments
argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
"Command %s got unexpected keyword arguments: %s" % (
self._pretty_fun(func), ", ".join(unexpected_kwargs)))
mandatory_args = argspec.args
if argspec.defaults: # remove those with default value
mandatory_args = mandatory_args[0:-len(argspec.defaults)]
missing_kwargs = set(mandatory_args) - set(args)
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
def _get_command(self, command, args):
if not isinstance(command, basestring) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
func = getattr(Context, command)
except AttributeError as e:
raise AttributeError(u'Command not found: %s (%s)' % (command, e))
if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." %
self._pretty_fun(func))
self._check_args(func, args)
logger.debug("_get_command finished")
return func
@staticmethod
def _pretty_fun(fun):
try:
argspec = getargspec(fun)
args = argspec.args
if argspec.varargs:
args.append("*" + argspec.varargs)
if argspec.keywords:
args.append("**" + argspec.keywords)
return "%s(%s)" % (fun.__name__, ",".join(args))
except Exception:
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
func = self._get_command(command, args)
retval = func(**args)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
def handle_response(self, response, args):
pass
def main():
# Get proper serial class and port name
(serial, port) = get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(), port, reactor)
try:
from notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
logger.setLevel(level)
if __name__ == '__main__': init_serial()
main()
...@@ -3,28 +3,26 @@ import json ...@@ -3,28 +3,26 @@ import json
import logging import logging
import platform import platform
logger = logging.getLogger() logger = logging.getLogger()
system = platform.system()
class SerialLineReceiverBase(LineReceiver, object): class SerialLineReceiverBase(LineReceiver, object):
MAX_LENGTH = 1024*1024*128 MAX_LENGTH = 1024*1024*128
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if system == "FreeBSD": if platform.system() == "FreeBSD":
self.delimiter = '\n' self.delimiter = b'\n'
else: else:
self.delimiter = '\r' self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs) super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args): def send_response(self, response, args):
self.transport.write(json.dumps({'response': response, self.transport.write((json.dumps({'response': response,
'args': args}) + '\r\n') 'args': args}) + '\r\n').encode())
def send_command(self, command, args): def send_command(self, command, args):
self.transport.write(json.dumps({'command': command, self.transport.write((json.dumps({'command': command,
'args': args}) + '\r\n') 'args': args}) + '\r\n').encode())
def handle_command(self, command, args): def handle_command(self, command, args):
raise NotImplementedError("Subclass must implement abstract method") raise NotImplementedError("Subclass must implement abstract method")
...@@ -45,12 +43,12 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -45,12 +43,12 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.error('[serial] invalid json: %s (%s)' % (data, e)) logger.error('[serial] invalid json: %s (%s)' % (data, e))
return return
if command is not None and isinstance(command, unicode): if command is not None and isinstance(command, str):
logger.debug('received command: %s (%s)' % (command, args)) logger.debug('received command: %s (%s)' % (command, args))
try: try:
self.handle_command(command, args) self.handle_command(command, args)
except Exception as e: except Exception as e:
logger.exception(u'Unhandled exception: ') logger.exception(u'Unhandled exception: ')
elif response is not None and isinstance(response, unicode): elif response is not None and isinstance(response, str):
logger.debug('received reply: %s (%s)' % (response, args)) logger.debug('received reply: %s (%s)' % (response, args))
self.handle_response(response, args) self.handle_response(response, args)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import environ
import platform
from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall
import uptime
import logging
from inspect import getfullargspec, isfunction
from agent.SerialLineReceiverBase import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from agent.context import BaseContext, get_context # noqa
Context = get_context()
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.transport.write(b'\r\n')
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': platform.system()})
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionLost(self, reason):
reactor.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def send_status(self):
import psutil
disk_usage = dict((disk.device.replace('/', '_'),
psutil.disk_usage(disk.mountpoint).percent)
for disk in psutil.disk_partitions())
args = {"cpu": dict(psutil.cpu_times()._asdict()),
"ram": dict(psutil.virtual_memory()._asdict()),
"swap": dict(psutil.swap_memory()._asdict()),
"uptime": {"seconds": uptime.uptime()},
"disk": disk_usage,
"user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args)
logger.debug("send_status finished")
def _check_args(self, func, args):
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
(self._pretty_fun(func), type(args).__name__))
# check for unexpected keyword arguments
argspec = getfullargspec(func)
if argspec.varkw is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
"Command %s got unexpected keyword arguments: %s" % (
self._pretty_fun(func), ", ".join(unexpected_kwargs)))
mandatory_args = argspec.args
if argspec.defaults: # remove those with default value
mandatory_args = mandatory_args[0:-len(argspec.defaults)]
missing_kwargs = set(mandatory_args) - set(args)
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
def _get_command(self, command, args):
if not isinstance(command, str) or command.startswith('_'):
raise AttributeError('Invalid command: %s' % command)
try:
func = getattr(Context, command)
except AttributeError as e:
raise AttributeError('Command not found: %s (%s)' % (command, e))
if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." %
self._pretty_fun(func))
self._check_args(func, args)
logger.debug("_get_command finished")
return func
@staticmethod
def _pretty_fun(fun):
try:
argspec = getfullargspec(fun)
args = argspec.args
if argspec.varargs:
args.append("*" + argspec.varargs)
if argspec.varkw:
args.append("**" + argspec.varkw)
return "%s(%s)" % (fun.__name__, ",".join(args))
except Exception:
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
func = self._get_command(command, args)
retval = func(**args)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
def handle_response(self, response, args):
pass
def init_serial():
# Get proper serial class and port name
(serial, port) = Context.get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(), port, reactor)
try:
from agent.notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
...@@ -4,71 +4,19 @@ ...@@ -4,71 +4,19 @@
import platform import platform
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
def get_context(): def get_context():
system = platform.system() system = platform.system()
if system == "Windows": if system == "Windows":
from windows._win32context import Context from agent.windows._win32context import Context
elif system == "Linux": elif system == "Linux":
from linux._linuxcontext import Context from agent.linux._linuxcontext import Context
elif system == "FreeBSD": elif system == "FreeBSD":
from freebsd._freebsdcontext import Context from agent.freebsd._freebsdcontext import Context
else: else:
raise NotImplementedError("Platform %s is not supported.", system) raise NotImplementedError("Platform %s is not supported.", system)
return Context return Context
def get_serial():
system = platform.system()
port = None
if system == 'Windows':
port = _get_virtio_device()
import pythoncom
pythoncom.CoInitialize()
if port:
from windows.win32virtio import SerialPort
else:
from twisted.internet.serialport import SerialPort
port = r'\\.\COM1'
elif system == "Linux":
port = "/dev/virtio-ports/agent"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0'
else:
from linux.posixvirtio import SerialPort
elif system == "FreeBSD":
port = "/dev/ttyV0.1"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyu0'
else:
from freebsd.posixvirtio import SerialPort
else:
raise NotImplementedError("Platform %s is not supported.", system)
return (SerialPort, port)
class BaseContext(object): class BaseContext(object):
@staticmethod @staticmethod
def change_password(password): def change_password(password):
...@@ -136,5 +84,9 @@ class BaseContext(object): ...@@ -136,5 +84,9 @@ class BaseContext(object):
@staticmethod @staticmethod
def send_expiration(url): def send_expiration(url):
import notify from agent import notify
notify.notify(url) notify.notify(url)
@staticmethod
def get_serial():
raise NotImplementedError()
...@@ -27,14 +27,14 @@ import fileinput ...@@ -27,14 +27,14 @@ import fileinput
import tarfile import tarfile
from os.path import expanduser, join, exists from os.path import expanduser, join, exists
from glob import glob from glob import glob
from StringIO import StringIO from io import StringIO
from base64 import decodestring from base64 import decodestring
from hashlib import md5 from hashlib import md5
from ssh import PubKey from .ssh import PubKey
from .network import change_ip_freebsd from .network import change_ip_freebsd
from context import BaseContext from agent.context import BaseContext
from twisted.internet import reactor from twisted.internet import reactor
...@@ -69,7 +69,6 @@ class Context(BaseContext): ...@@ -69,7 +69,6 @@ class Context(BaseContext):
# python-module-to-change-system-date-and-time # python-module-to-change-system-date-and-time
@staticmethod @staticmethod
def _freebsd_set_time(time): def _freebsd_set_time(time):
import ctypes
import ctypes.util import ctypes.util
CLOCK_REALTIME = 0 CLOCK_REALTIME = 0
...@@ -135,10 +134,10 @@ class Context(BaseContext): ...@@ -135,10 +134,10 @@ class Context(BaseContext):
# TODO # TODO
for line in fileinput.input('/etc/fstab', inplace=True): for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' smbfs ' in line): if not (line.startswith('//') and ' smbfs ' in line):
print line.rstrip() print(line.rstrip())
with open(NSMBRC, 'w') as f: with open(NSMBRC, 'w') as f:
chmod(NSMBRC, 0600) chmod(NSMBRC, 0o600)
f.write(nsmbrc_template_freebsd % data) f.write(nsmbrc_template_freebsd % data)
with open('/etc/fstab', 'a') as f: with open('/etc/fstab', 'a') as f:
...@@ -155,21 +154,21 @@ class Context(BaseContext): ...@@ -155,21 +154,21 @@ class Context(BaseContext):
try: try:
retval.append(PubKey.from_str(line)) retval.append(PubKey.from_str(line))
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
except IOError: except IOError:
pass pass
return retval return retval
@staticmethod @staticmethod
def _save_keys(keys): def _save_keys(keys):
print keys print(keys)
try: try:
mkdir(SSH_DIR) mkdir(SSH_DIR)
except OSError: except OSError:
pass pass
with open(AUTHORIZED_KEYS, 'w') as f: with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys: for key in keys:
f.write(unicode(key) + '\n') f.write(str(key) + '\n')
@staticmethod @staticmethod
def add_keys(keys): def add_keys(keys):
...@@ -180,7 +179,7 @@ class Context(BaseContext): ...@@ -180,7 +179,7 @@ class Context(BaseContext):
if p not in new_keys: if p not in new_keys:
new_keys.append(p) new_keys.append(p)
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) Context._save_keys(new_keys)
@staticmethod @staticmethod
...@@ -194,7 +193,7 @@ class Context(BaseContext): ...@@ -194,7 +193,7 @@ class Context(BaseContext):
except ValueError: except ValueError:
pass pass
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) Context._save_keys(new_keys)
@staticmethod @staticmethod
...@@ -280,5 +279,17 @@ class Context(BaseContext): ...@@ -280,5 +279,17 @@ class Context(BaseContext):
@staticmethod @staticmethod
def send_expiration(url): def send_expiration(url):
import notify from agent import notify
notify.notify(url) notify.notify(url)
@staticmethod
def get_serial():
port = "/dev/ttyV0.1"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyu0'
else:
from agent.freebsd import SerialPort
return (SerialPort, port)
...@@ -37,7 +37,7 @@ def change_ip_freebsd(interfaces, nameservers): ...@@ -37,7 +37,7 @@ def change_ip_freebsd(interfaces, nameservers):
data = list(get_interfaces_freebsd(interfaces)) data = list(get_interfaces_freebsd(interfaces))
for ifname, conf in data: for ifname, conf in data:
subprocess.call(('/usr/sbin/service', 'netif', 'stop', ifname)) subprocess.call(('/usr/sbin/service', 'netif', 'stop', ifname))
remove_interfaces_freebsd(dict(data).keys()) remove_interfaces_freebsd(list(dict(data).keys()))
for device, conf in data: for device, conf in data:
if_file = rcconf_dir + "ifconfig_" + device if_file = rcconf_dir + "ifconfig_" + device
......
...@@ -37,10 +37,10 @@ class PubKey(object): ...@@ -37,10 +37,10 @@ class PubKey(object):
PubKey.validate_key(key_type, key) PubKey.validate_key(key_type, key)
self.key = key self.key = key
self.comment = unicode(comment) self.comment = str(comment)
def __hash__(self): def __hash__(self):
return hash(frozenset(self.__dict__.items())) return hash(frozenset(list(self.__dict__.items())))
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
...@@ -51,10 +51,10 @@ class PubKey(object): ...@@ -51,10 +51,10 @@ class PubKey(object):
return PubKey(key_type, key, comment) return PubKey(key_type, key, comment)
def __unicode__(self): def __unicode__(self):
return u' '.join((self.key_type, self.key, self.comment)) return ' '.join((self.key_type, self.key, self.comment))
def __repr__(self): def __repr__(self):
return u'<PubKey: %s>' % unicode(self) return '<PubKey: %s>' % str(self)
# Unit tests # Unit tests
...@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase): ...@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase):
def test_unicode(self): def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment') p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment') self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self): def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment') p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment') self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self): def test_eq(self):
self.assertEqual(self.p1, self.p2) self.assertEqual(self.p1, self.p2)
......
...@@ -17,15 +17,12 @@ import fileinput ...@@ -17,15 +17,12 @@ import fileinput
import tarfile import tarfile
from os.path import expanduser, join, exists from os.path import expanduser, join, exists
from glob import glob from glob import glob
from StringIO import StringIO from io import StringIO
from base64 import decodestring from base64 import decodestring
from hashlib import md5 from hashlib import md5
from agent.linux.ssh import PubKey
from agent.linux.network import change_ip_ubuntu, change_ip_rhel
from ssh import PubKey from agent.context import BaseContext
from .network import change_ip_ubuntu, change_ip_rhel
from context import BaseContext
from twisted.internet import reactor from twisted.internet import reactor
...@@ -56,7 +53,6 @@ class Context(BaseContext): ...@@ -56,7 +53,6 @@ class Context(BaseContext):
# python-module-to-change-system-date-and-time # python-module-to-change-system-date-and-time
@staticmethod @staticmethod
def _linux_set_time(time): def _linux_set_time(time):
import ctypes
import ctypes.util import ctypes.util
CLOCK_REALTIME = 0 CLOCK_REALTIME = 0
...@@ -77,7 +73,7 @@ class Context(BaseContext): ...@@ -77,7 +73,7 @@ class Context(BaseContext):
def change_password(password): def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'], proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE) stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password) proc.communicate(('cloud:%s\n' % password).encode())
@staticmethod @staticmethod
def restart_networking(): def restart_networking():
...@@ -123,7 +119,7 @@ class Context(BaseContext): ...@@ -123,7 +119,7 @@ class Context(BaseContext):
# TODO # TODO
for line in fileinput.input('/etc/fstab', inplace=True): for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line): if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip() print(line.rstrip())
with open('/etc/fstab', 'a') as f: with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data) f.write(mount_template_linux % data)
...@@ -139,7 +135,7 @@ class Context(BaseContext): ...@@ -139,7 +135,7 @@ class Context(BaseContext):
try: try:
retval.append(PubKey.from_str(line)) retval.append(PubKey.from_str(line))
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
except IOError: except IOError:
pass pass
return retval return retval
...@@ -152,7 +148,7 @@ class Context(BaseContext): ...@@ -152,7 +148,7 @@ class Context(BaseContext):
pass pass
with open(AUTHORIZED_KEYS, 'w') as f: with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys: for key in keys:
f.write(unicode(key) + '\n') f.write(str(key) + '\n')
uid = getpwnam("cloud").pw_uid uid = getpwnam("cloud").pw_uid
chown(SSH_DIR, uid, -1) chown(SSH_DIR, uid, -1)
...@@ -167,7 +163,7 @@ class Context(BaseContext): ...@@ -167,7 +163,7 @@ class Context(BaseContext):
if p not in new_keys: if p not in new_keys:
new_keys.append(p) new_keys.append(p)
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) Context._save_keys(new_keys)
@staticmethod @staticmethod
...@@ -181,7 +177,7 @@ class Context(BaseContext): ...@@ -181,7 +177,7 @@ class Context(BaseContext):
except ValueError: except ValueError:
pass pass
except Exception: except Exception:
logger.exception(u'Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) Context._save_keys(new_keys)
@staticmethod @staticmethod
...@@ -263,5 +259,17 @@ class Context(BaseContext): ...@@ -263,5 +259,17 @@ class Context(BaseContext):
@staticmethod @staticmethod
def send_expiration(url): def send_expiration(url):
import notify from agent import notify
notify.notify(url) notify.notify(url)
@staticmethod
def get_serial():
port = "/dev/virtio-ports/agent"
try:
open(port, 'w').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0'
else:
from agent.linux.posixvirtio import SerialPort
return (SerialPort, port)
...@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices): ...@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices):
if line.startswith('#') or line == '' or line.isspace() or not words: if line.startswith('#') or line == '' or line.isspace() or not words:
# keep line # keep line
print line print(line)
continue continue
if (words[0] in ('auto', 'allow-hotplug') and if (words[0] in ('auto', 'allow-hotplug') and
...@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices): ...@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices):
continue continue
# keep line # keep line
print line print(line)
def change_ip_ubuntu(interfaces, nameservers): def change_ip_ubuntu(interfaces, nameservers):
...@@ -60,7 +60,7 @@ def change_ip_ubuntu(interfaces, nameservers): ...@@ -60,7 +60,7 @@ def change_ip_ubuntu(interfaces, nameservers):
subprocess.call(('/sbin/ip', 'addr', 'flush', 'dev', ifname)) subprocess.call(('/sbin/ip', 'addr', 'flush', 'dev', ifname))
subprocess.call(('/sbin/ip', 'link', 'set', 'dev', ifname, subprocess.call(('/sbin/ip', 'link', 'set', 'dev', ifname,
'down')) 'down'))
remove_interfaces_ubuntu(dict(data).keys()) remove_interfaces_ubuntu(list(dict(data).keys()))
with open(interfaces_file, 'a') as f: with open(interfaces_file, 'a') as f:
for ifname, conf in data: for ifname, conf in data:
......
...@@ -37,10 +37,10 @@ class PubKey(object): ...@@ -37,10 +37,10 @@ class PubKey(object):
PubKey.validate_key(key_type, key) PubKey.validate_key(key_type, key)
self.key = key self.key = key
self.comment = unicode(comment) self.comment = str(comment)
def __hash__(self): def __hash__(self):
return hash(frozenset(self.__dict__.items())) return hash(frozenset(list(self.__dict__.items())))
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
...@@ -51,10 +51,10 @@ class PubKey(object): ...@@ -51,10 +51,10 @@ class PubKey(object):
return PubKey(key_type, key, comment) return PubKey(key_type, key, comment)
def __unicode__(self): def __unicode__(self):
return u' '.join((self.key_type, self.key, self.comment)) return ' '.join((self.key_type, self.key, self.comment))
def __repr__(self): def __repr__(self):
return u'<PubKey: %s>' % unicode(self) return '<PubKey: %s>' % str(self)
# Unit tests # Unit tests
...@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase): ...@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase):
def test_unicode(self): def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment') p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment') self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self): def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment') p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment') self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self): def test_eq(self):
self.assertEqual(self.p1, self.p2) self.assertEqual(self.p1, self.p2)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# Notify user about vm expiring # Notify user about vm expiring
## ##
import cookielib import http.cookiejar
import errno import errno
import json import json
import logging import logging
...@@ -13,8 +13,8 @@ import multiprocessing ...@@ -13,8 +13,8 @@ import multiprocessing
import os import os
import platform import platform
import subprocess import subprocess
import urllib2 import urllib.request, urllib.error, urllib.parse
from urlparse import urlsplit from urllib.parse import urlsplit
logger = logging.getLogger() logger = logging.getLogger()
logger.debug("notify imported") logger.debug("notify imported")
...@@ -58,19 +58,19 @@ def accept(): ...@@ -58,19 +58,19 @@ def accept():
from pytz import UTC from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name) file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
print "There is no recent notification to accept." print("There is no recent notification to accept.")
return False return False
# Load the saved url # Load the saved url
url = json.load(open(file_path, "r")) url = json.load(open(file_path, "r"))
cj = cookielib.CookieJar() cj = http.cookiejar.CookieJar()
opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj)) opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(cj))
try: try:
opener.open(url) # GET to collect cookies opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url)) cookies = cj._cookies_for_request(urllib.request.Request(url))
token = [c for c in cookies if c.name == "csrftoken"][0].value token = [c for c in cookies if c.name == "csrftoken"][0].value
req = urllib2.Request(url, "", { req = urllib.request.Request(url, "", {
"accept": "application/json", "referer": url, "accept": "application/json", "referer": url,
"x-csrftoken": token}) "x-csrftoken": token})
rsp = opener.open(req) rsp = opener.open(req)
...@@ -83,10 +83,10 @@ def accept(): ...@@ -83,10 +83,10 @@ def accept():
new_local_time = parsed_time.astimezone( new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S") get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e: except ValueError as e:
print "Parsing time failed: %s" % e print("Parsing time failed: %s" % e)
except Exception as e: except Exception as e:
print e print(e)
print "Renewal failed. Please try it manually at %s" % url print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed") logger.exception("renew failed")
return False return False
else: else:
...@@ -102,7 +102,7 @@ def notify(url): ...@@ -102,7 +102,7 @@ def notify(url):
if win: if win:
logger.info("notifying %d clients", len(clients)) logger.info("notifying %d clients", len(clients))
for c in clients: for c in clients:
logger.debug("sending url %s to client %s", url, unicode(c)) logger.debug("sending url %s to client %s", url, str(c))
c.sendLine(url.encode()) c.sendLine(url.encode())
else: else:
file_path = os.path.join(get_temp_dir(), file_name) file_path = os.path.join(get_temp_dir(), file_name)
...@@ -209,11 +209,11 @@ if win: ...@@ -209,11 +209,11 @@ if win:
self.factory = factory self.factory = factory
def connectionMade(self): def connectionMade(self):
logger.info("client connected: %s", unicode(self)) logger.info("client connected: %s", str(self))
clients.add(self) clients.add(self)
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("client disconnected: %s", unicode(self)) logger.info("client disconnected: %s", str(self))
clients.remove(self) clients.remove(self)
class PubFactory(protocol.Factory): class PubFactory(protocol.Factory):
...@@ -230,7 +230,7 @@ if win: ...@@ -230,7 +230,7 @@ if win:
class SubProtocol(basic.LineReceiver): class SubProtocol(basic.LineReceiver):
def lineReceived(self, line): def lineReceived(self, line):
print "received", line print("received", line)
if line.startswith('cifs://'): if line.startswith('cifs://'):
mount_smb(line) mount_smb(line)
else: else:
...@@ -243,7 +243,7 @@ if win: ...@@ -243,7 +243,7 @@ if win:
def run_client(): def run_client():
from twisted.internet import reactor from twisted.internet import reactor
print "connect to localhost:%d" % port print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory()) reactor.connectTCP("localhost", port, SubFactory())
reactor.run() reactor.run()
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
working_directory = r"C:\circle" # noqa
from os.path import join from os.path import join
import logging import logging
import tarfile from tarfile import TarFile, ReadError
from StringIO import StringIO from io import StringIO
from base64 import decodestring from base64 import b64decode
from hashlib import md5 from hashlib import md5
from datetime import datetime from datetime import datetime
import win32api import win32api
import wmi import wmi
import netifaces import netifaces
from twisted.internet import reactor from twisted.internet import reactor
from agent.windows.network import change_ip_windows
from .network import change_ip_windows from agent.context import BaseContext
from context import BaseContext
logger = logging.getLogger() logger = logging.getLogger()
working_directory = r"C:\circle" # noqa
class Context(BaseContext): class Context(BaseContext):
...@@ -54,10 +48,10 @@ class Context(BaseContext): ...@@ -54,10 +48,10 @@ class Context(BaseContext):
@staticmethod @staticmethod
def mount_store(host, username, password): def mount_store(host, username, password):
import notify from agent import notify
url = 'cifs://%s:%s@%s/%s' % (username, password, host, username) url = 'cifs://%s:%s@%s/%s' % (username, password, host, username)
for c in notify.clients: for c in notify.clients:
logger.debug("sending url %s to client %s", url, unicode(c)) logger.debug("sending url %s to client %s", url, str(c))
c.sendLine(url.encode()) c.sendLine(url.encode())
@staticmethod @staticmethod
...@@ -94,7 +88,7 @@ class Context(BaseContext): ...@@ -94,7 +88,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def _update_registry(dir, executable): def _update_registry(dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent # HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from _winreg import (OpenKeyEx, SetValueEx, QueryValueEx, from winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS) HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE, with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent', r'SYSTEM\CurrentControlSet\services\circle-agent',
...@@ -111,11 +105,11 @@ class Context(BaseContext): ...@@ -111,11 +105,11 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest() local_checksum = md5(data).hexdigest()
if local_checksum != checksum: if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.") raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data)) decoded = StringIO(b64decode(data))
try: try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz') tar = TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory) tar.extractall(working_directory)
except tarfile.ReadError as e: except ReadError as e:
logger.error(e) logger.error(e)
logger.info("Transfer completed!") logger.info("Transfer completed!")
Context._update_registry(working_directory, executable) Context._update_registry(working_directory, executable)
...@@ -145,3 +139,32 @@ class Context(BaseContext): ...@@ -145,3 +139,32 @@ class Context(BaseContext):
return f.readline() return f.readline()
except IOError: except IOError:
return None return None
@staticmethod
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
@staticmethod
def get_serial():
port = Context._get_virtio_device()
import pythoncom
pythoncom.CoInitialize()
if port:
from agent.windows.win32virtio import SerialPort
else:
from twisted.internet.serialport import SerialPort
port = r'\\.\COM1'
return (SerialPort, port)
...@@ -81,9 +81,9 @@ class SerialPort(abstract.FileDescriptor): ...@@ -81,9 +81,9 @@ class SerialPort(abstract.FileDescriptor):
except Exception: except Exception:
import time import time
time.sleep(10) time.sleep(10)
n = 0 n = None
if n: if n:
first = str(self.read_buf[:n]) first = bytes(self.read_buf[:n])
# now we should get everything that is already in the buffer (max # now we should get everything that is already in the buffer (max
# 4096) # 4096)
win32event.ResetEvent(self._overlappedRead.hEvent) win32event.ResetEvent(self._overlappedRead.hEvent)
...@@ -95,7 +95,7 @@ class SerialPort(abstract.FileDescriptor): ...@@ -95,7 +95,7 @@ class SerialPort(abstract.FileDescriptor):
self._overlappedRead, self._overlappedRead,
1) 1)
# handle all the received data: # handle all the received data:
self.protocol.dataReceived(first + str(buf[:n])) self.protocol.dataReceived(first + bytes(buf[:n]))
# set up next one # set up next one
win32event.ResetEvent(self._overlappedRead.hEvent) win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort, rc, self.read_buf = win32file.ReadFile(self.hComPort,
......
...@@ -11,7 +11,7 @@ fh.setFormatter(formatter) ...@@ -11,7 +11,7 @@ fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
from notify import run_client from agent.notify import run_client
if __name__ == '__main__': if __name__ == '__main__':
run_client() run_client()
#!/usr/bin/env python #!/usr/bin/env python
import notify from agent import notify
if __name__ == '__main__': if __name__ == '__main__':
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment