Commit eb2eeb12 by carpoon

code cleanup and refactoring

 * import cleanup
 * Create distinct names classes for system specific Contexts
 * Move SerialLineReciver to a separate file, pass Context for it in the constructor
 * Move nameserver parsing to base class
parent 2fe1567b
...@@ -7,7 +7,7 @@ import sys ...@@ -7,7 +7,7 @@ import sys
import win32event import win32event
import win32serviceutil import win32serviceutil
import win32service import win32service
from agent.agent import init_serial, reactor from agent.main import init_serial, reactor
logger = logging.getLogger() logger = logging.getLogger()
fh = NTEventLogHandler("CIRCLE Agent", dllname=os.path.dirname(__file__)) fh = NTEventLogHandler("CIRCLE Agent", dllname=os.path.dirname(__file__))
......
...@@ -16,13 +16,7 @@ try: # noqa ...@@ -16,13 +16,7 @@ try: # noqa
except Exception as e: # noqa except Exception as e: # noqa
pass # hope it works # noqa pass # hope it works # noqa
from agent.agent import init_serial, reactor from agent.main import init_serial
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from agent.context import BaseContext, get_context # noqa
Context = get_context()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()
......
""" This is the defautl context file. It replaces the Context class """ This is the default context file. It replaces the Context class
to the platform specific one. to the platform specific one.
""" """
import platform
def get_context():
system = platform.system()
if system == "Windows":
from agent.windows._win32context import Context
elif system == "Linux":
from agent.linux._linuxcontext import Context
elif system == "FreeBSD":
from agent.freebsd._freebsdcontext import Context
else:
raise NotImplementedError("Platform %s is not supported.", system)
return Context
class BaseContext(object): class BaseContext(object):
...@@ -90,3 +76,7 @@ class BaseContext(object): ...@@ -90,3 +76,7 @@ class BaseContext(object):
@staticmethod @staticmethod
def get_serial(): def get_serial():
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def _parse_nameserver(nameservers):
return [s.strip() for s in nameservers.split(',')]
#!/usr/bin/env python from agent.SerialLineReceiverBase import SerialLineReceiverBase
# -*- coding: utf-8 -*- from agent.BaseContext import BaseContext
from os import environ
import platform import platform
from twisted.internet import reactor, defer import logging
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.internet import reactor, defer
import uptime import uptime
import logging
from inspect import getfullargspec, isfunction from inspect import getfullargspec, isfunction
from agent.SerialLineReceiverBase import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from agent.context import BaseContext, get_context # noqa
Context = get_context()
logging.basicConfig()
logger = logging.getLogger() logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase): class SerialLineReceiver(SerialLineReceiverBase):
...@@ -32,7 +16,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -32,7 +16,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
self.transport.write(b'\r\n') self.transport.write(b'\r\n')
self.send_command( self.send_command(
command='agent_started', command='agent_started',
args={'version': Context.get_agent_version(), 'system': platform.system()}) args={'version': self.context.get_agent_version(), 'system': platform.system()})
def shutdown(): def shutdown():
self.connectionLost2('shutdown') self.connectionLost2('shutdown')
...@@ -45,8 +29,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -45,8 +29,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
reactor.stop() reactor.stop()
def connectionLost2(self, reason): def connectionLost2(self, reason):
self.send_command(command='agent_stopped', self.send_command(command='agent_stopped', args={})
args={})
def tick(self): def tick(self):
logger.debug("Sending tick") logger.debug("Sending tick")
...@@ -55,10 +38,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -55,10 +38,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
except Exception: except Exception:
logger.exception("Twisted hide exception") logger.exception("Twisted hide exception")
def __init__(self): def __init__(self, context: BaseContext):
super(SerialLineReceiver, self).__init__() super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick) self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False) self.lc.start(5, now=False)
self.context = context
def send_status(self): def send_status(self):
import psutil import psutil
...@@ -101,13 +85,12 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -101,13 +85,12 @@ class SerialLineReceiver(SerialLineReceiverBase):
if not isinstance(command, str) or command.startswith('_'): if not isinstance(command, str) or command.startswith('_'):
raise AttributeError('Invalid command: %s' % command) raise AttributeError('Invalid command: %s' % command)
try: try:
func = getattr(Context, command) func = getattr(self.context, command)
except AttributeError as e: except AttributeError as e:
raise AttributeError('Command not found: %s (%s)' % (command, e)) raise AttributeError('Command not found: %s (%s)' % (command, e))
if not isfunction(func): if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." % raise AttributeError("Command refers to non-static method %s." % self._pretty_fun(func))
self._pretty_fun(func))
self._check_args(func, args) self._check_args(func, args)
logger.debug("_get_command finished") logger.debug("_get_command finished")
...@@ -135,19 +118,3 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -135,19 +118,3 @@ class SerialLineReceiver(SerialLineReceiverBase):
def handle_response(self, response, args): def handle_response(self, response, args):
pass pass
def init_serial():
# Get proper serial class and port name
(serial, port) = Context.get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(), port, reactor)
try:
from agent.notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
...@@ -34,7 +34,7 @@ from hashlib import md5 ...@@ -34,7 +34,7 @@ from hashlib import md5
from .ssh import PubKey from .ssh import PubKey
from .network import change_ip_freebsd from .network import change_ip_freebsd
from agent.context import BaseContext from agent.BaseContext import BaseContext
from twisted.internet import reactor from twisted.internet import reactor
...@@ -63,7 +63,7 @@ nsmbrc_template_freebsd = ( ...@@ -63,7 +63,7 @@ nsmbrc_template_freebsd = (
'password=%(password)s\n') 'password=%(password)s\n')
class Context(BaseContext): class FreeBSDContext(BaseContext):
# http://stackoverflow.com/questions/12081310/ # http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time # python-module-to-change-system-date-and-time
...@@ -90,11 +90,11 @@ class Context(BaseContext): ...@@ -90,11 +90,11 @@ class Context(BaseContext):
proc0 = subprocess.Popen( proc0 = subprocess.Popen(
['/usr/sbin/pw', 'user', 'mod', 'cloud', '-h', '0'], ['/usr/sbin/pw', 'user', 'mod', 'cloud', '-h', '0'],
stdin=subprocess.PIPE) stdin=subprocess.PIPE)
proc0.communicate('%s\n' % password) proc0.communicate(('%s\n' % password).encode())
proc1 = subprocess.Popen( proc1 = subprocess.Popen(
['/usr/sbin/pw', 'user', 'mod', 'root', '-h', '0'], ['/usr/sbin/pw', 'user', 'mod', 'root', '-h', '0'],
stdin=subprocess.PIPE) stdin=subprocess.PIPE)
proc1.communicate('%s\n' % password) proc1.communicate(('%s\n' % password).encode())
@staticmethod @staticmethod
def restart_networking(): def restart_networking():
...@@ -102,12 +102,12 @@ class Context(BaseContext): ...@@ -102,12 +102,12 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
nameservers = dns.replace(' ', '').split(',') nameservers = FreeBSDContext._parse_nameserver(dns)
change_ip_freebsd(interfaces, nameservers) change_ip_freebsd(interfaces, nameservers)
@staticmethod @staticmethod
def set_time(time): def set_time(time):
Context._freebsd_set_time(float(time)) FreeBSDContext._freebsd_set_time(float(time))
try: try:
subprocess.call(['/usr/sbin/service' 'ntpd', 'onerestart']) subprocess.call(['/usr/sbin/service' 'ntpd', 'onerestart'])
except Exception: except Exception:
...@@ -172,7 +172,7 @@ class Context(BaseContext): ...@@ -172,7 +172,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def add_keys(keys): def add_keys(keys):
new_keys = Context.get_keys() new_keys = FreeBSDContext.get_keys()
for key in keys: for key in keys:
try: try:
p = PubKey.from_str(key) p = PubKey.from_str(key)
...@@ -180,11 +180,11 @@ class Context(BaseContext): ...@@ -180,11 +180,11 @@ class Context(BaseContext):
new_keys.append(p) new_keys.append(p)
except Exception: except Exception:
logger.exception('Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) FreeBSDContext._save_keys(new_keys)
@staticmethod @staticmethod
def del_keys(keys): def del_keys(keys):
new_keys = Context.get_keys() new_keys = FreeBSDContext.get_keys()
for key in keys: for key in keys:
try: try:
p = PubKey.from_str(key) p = PubKey.from_str(key)
...@@ -194,7 +194,7 @@ class Context(BaseContext): ...@@ -194,7 +194,7 @@ class Context(BaseContext):
pass pass
except Exception: except Exception:
logger.exception('Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) FreeBSDContext._save_keys(new_keys)
@staticmethod @staticmethod
def cleanup(): def cleanup():
......
...@@ -22,7 +22,7 @@ from base64 import decodestring ...@@ -22,7 +22,7 @@ from base64 import decodestring
from hashlib import md5 from hashlib import md5
from agent.linux.ssh import PubKey from agent.linux.ssh import PubKey
from agent.linux.network import change_ip_ubuntu, change_ip_rhel from agent.linux.network import change_ip_ubuntu, change_ip_rhel
from agent.context import BaseContext from agent.BaseContext import BaseContext
from twisted.internet import reactor from twisted.internet import reactor
...@@ -47,7 +47,7 @@ distros = {'Scientific Linux': 'rhel', ...@@ -47,7 +47,7 @@ distros = {'Scientific Linux': 'rhel',
distro = distros[platform.linux_distribution()[0]] distro = distros[platform.linux_distribution()[0]]
class Context(BaseContext): class LinuxContext(BaseContext):
# http://stackoverflow.com/questions/12081310/ # http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time # python-module-to-change-system-date-and-time
...@@ -85,7 +85,7 @@ class Context(BaseContext): ...@@ -85,7 +85,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
nameservers = dns.replace(' ', '').split(',') nameservers = LinuxContext._parse_nameserver(dns)
if distro == 'debian': if distro == 'debian':
change_ip_ubuntu(interfaces, nameservers) change_ip_ubuntu(interfaces, nameservers)
elif distro == 'rhel': elif distro == 'rhel':
...@@ -93,7 +93,7 @@ class Context(BaseContext): ...@@ -93,7 +93,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def set_time(time): def set_time(time):
Context._linux_set_time(float(time)) LinuxContext._linux_set_time(float(time))
try: try:
subprocess.call(['/etc/init.d/ntp', 'restart']) subprocess.call(['/etc/init.d/ntp', 'restart'])
except Exception: except Exception:
...@@ -112,8 +112,7 @@ class Context(BaseContext): ...@@ -112,8 +112,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def mount_store(host, username, password): def mount_store(host, username, password):
data = {'host': host, 'username': username, 'password': password} data = {'host': host, 'username': username, 'password': password, 'dir': STORE_DIR}
data['dir'] = STORE_DIR
if not exists(STORE_DIR): if not exists(STORE_DIR):
mkdir(STORE_DIR) mkdir(STORE_DIR)
# TODO # TODO
...@@ -156,7 +155,7 @@ class Context(BaseContext): ...@@ -156,7 +155,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def add_keys(keys): def add_keys(keys):
new_keys = Context.get_keys() new_keys = LinuxContext.get_keys()
for key in keys: for key in keys:
try: try:
p = PubKey.from_str(key) p = PubKey.from_str(key)
...@@ -164,11 +163,11 @@ class Context(BaseContext): ...@@ -164,11 +163,11 @@ class Context(BaseContext):
new_keys.append(p) new_keys.append(p)
except Exception: except Exception:
logger.exception('Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) LinuxContext._save_keys(new_keys)
@staticmethod @staticmethod
def del_keys(keys): def del_keys(keys):
new_keys = Context.get_keys() new_keys = LinuxContext.get_keys()
for key in keys: for key in keys:
try: try:
p = PubKey.from_str(key) p = PubKey.from_str(key)
...@@ -178,7 +177,7 @@ class Context(BaseContext): ...@@ -178,7 +177,7 @@ class Context(BaseContext):
pass pass
except Exception: except Exception:
logger.exception('Invalid ssh key: ') logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys) LinuxContext._save_keys(new_keys)
@staticmethod @staticmethod
def cleanup(): def cleanup():
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import environ
import platform
from twisted.internet import reactor, defer
import logging
from agent.SerialLineReceiver import SerialLineReceiver
def get_context():
system = platform.system()
if system == "Windows":
from agent.windows.WindowsContext import WindowsContext as Context
elif system == "Linux":
from agent.linux.LinuxContext import LinuxContext as Context
elif system == "FreeBSD":
from agent.freebsd.FreeBSDContext import FreeBSDContext as Context
else:
raise NotImplementedError("Platform %s is not supported.", system)
return Context
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
def init_serial():
context = get_context()
# Get proper serial class and port name
(serial, port) = context.get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(context), port, reactor)
try:
from agent.notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
...@@ -49,7 +49,7 @@ def wall(text): ...@@ -49,7 +49,7 @@ def wall(text):
logger.error("Incorrect function call") logger.error("Incorrect function call")
else: else:
process = subprocess.Popen("wall", stdin=subprocess.PIPE, shell=True) process = subprocess.Popen("wall", stdin=subprocess.PIPE, shell=True)
process.communicate(input=text)[0] process.communicate(input=text.encode())
def accept(): def accept():
......
...@@ -12,13 +12,13 @@ import wmi ...@@ -12,13 +12,13 @@ import wmi
import netifaces import netifaces
from twisted.internet import reactor from twisted.internet import reactor
from agent.windows.network import change_ip_windows from agent.windows.network import change_ip_windows
from agent.context import BaseContext from agent.BaseContext import BaseContext
logger = logging.getLogger() logger = logging.getLogger()
working_directory = r"C:\circle" # noqa working_directory = r"C:\circle" # noqa
class Context(BaseContext): class WindowsContext(BaseContext):
@staticmethod @staticmethod
def change_password(password): def change_password(password):
...@@ -33,7 +33,7 @@ class Context(BaseContext): ...@@ -33,7 +33,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
nameservers = dns.replace(' ', '').split(',') nameservers = WindowsContext._parse_nameserver(dns)
change_ip_windows(interfaces, nameservers) change_ip_windows(interfaces, nameservers)
@staticmethod @staticmethod
...@@ -112,7 +112,7 @@ class Context(BaseContext): ...@@ -112,7 +112,7 @@ class Context(BaseContext):
except ReadError as e: except ReadError as e:
logger.error(e) logger.error(e)
logger.info("Transfer completed!") logger.info("Transfer completed!")
Context._update_registry(working_directory, executable) WindowsContext._update_registry(working_directory, executable)
logger.info('Updated') logger.info('Updated')
reactor.stop() reactor.stop()
...@@ -159,7 +159,7 @@ class Context(BaseContext): ...@@ -159,7 +159,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def get_serial(): def get_serial():
port = Context._get_virtio_device() port = WindowsContext._get_virtio_device()
import pythoncom import pythoncom
pythoncom.CoInitialize() pythoncom.CoInitialize()
if port: if port:
......
# Open urls in default web browser provided by circle agent # Open urls in default web browser provided by circle agent
# Part of CIRCLE project http://circlecloud.org/ # Part of CIRCLE project http://circlecloud.org/
# Should be in autostart and run by the user logged in # Should be in autostart and run by the user logged in
import logging import logging
from agent.notify import run_client
logger = logging.getLogger() logger = logging.getLogger()
fh = logging.FileHandler("agent-client.log") fh = logging.FileHandler("agent-client.log")
formatter = logging.Formatter( formatter = logging.Formatter(
...@@ -11,7 +12,5 @@ fh.setFormatter(formatter) ...@@ -11,7 +12,5 @@ fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
from agent.notify import run_client
if __name__ == '__main__': if __name__ == '__main__':
run_client() run_client()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment