Commit 382cea4d by carpoon

migration first step

py2to3
Move all sources to agent folder to help imports
rename utils to SerialLineReceiverBase as it only contains this class
fix string and binary storage differences in python3
update requirements, add windows specific requirements
move platform specific logic to the platform specific classes
parent 6bdaacab
......@@ -5,16 +5,13 @@ import servicemanager
import socket
import sys
import win32event
import win32service
import win32serviceutil
from agent import main as agent_main, reactor
import win32service
from agent.agent import init_serial, reactor
logger = logging.getLogger()
fh = NTEventLogHandler(
"CIRCLE Agent", dllname=os.path.dirname(__file__))
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh = NTEventLogHandler("CIRCLE Agent", dllname=os.path.dirname(__file__))
formatter = logging.Formatter("%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO')
......@@ -42,7 +39,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, ''))
logger.info("%s starting", __file__)
agent_main()
init_serial()
def main():
......@@ -68,4 +65,4 @@ if __name__ == '__main__':
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
logger.exception("Exception:")
logger.exception("Exception: %s" % str(Exception))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import environ, chdir
import platform
import subprocess
import sys
system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
from twisted.internet import reactor, defer
......@@ -23,150 +18,22 @@ from twisted.internet.task import LoopingCall
import uptime
import logging
from inspect import getargspec, isfunction
from utils import SerialLineReceiverBase
from inspect import getfullargspec, isfunction
from agent.SerialLineReceiverBase import SerialLineReceiverBase
from agent.agent import init_serial, reactor
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from context import BaseContext, get_context, get_serial # noqa
from agent.context import BaseContext, get_context # noqa
Context = get_context()
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.transport.write('\r\n')
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(),
'system': system})
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionLost(self, reason):
reactor.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def send_status(self):
import psutil
disk_usage = dict((disk.device.replace('/', '_'),
psutil.disk_usage(disk.mountpoint).percent)
for disk in psutil.disk_partitions())
args = {"cpu": dict(psutil.cpu_times()._asdict()),
"ram": dict(psutil.virtual_memory()._asdict()),
"swap": dict(psutil.swap_memory()._asdict()),
"uptime": {"seconds": uptime.uptime()},
"disk": disk_usage,
"user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args)
logger.debug("send_status finished")
def _check_args(self, func, args):
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
(self._pretty_fun(func), type(args).__name__))
# check for unexpected keyword arguments
argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
"Command %s got unexpected keyword arguments: %s" % (
self._pretty_fun(func), ", ".join(unexpected_kwargs)))
mandatory_args = argspec.args
if argspec.defaults: # remove those with default value
mandatory_args = mandatory_args[0:-len(argspec.defaults)]
missing_kwargs = set(mandatory_args) - set(args)
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
def _get_command(self, command, args):
if not isinstance(command, basestring) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
func = getattr(Context, command)
except AttributeError as e:
raise AttributeError(u'Command not found: %s (%s)' % (command, e))
if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." %
self._pretty_fun(func))
self._check_args(func, args)
logger.debug("_get_command finished")
return func
@staticmethod
def _pretty_fun(fun):
try:
argspec = getargspec(fun)
args = argspec.args
if argspec.varargs:
args.append("*" + argspec.varargs)
if argspec.keywords:
args.append("**" + argspec.keywords)
return "%s(%s)" % (fun.__name__, ",".join(args))
except Exception:
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
func = self._get_command(command, args)
retval = func(**args)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
def handle_response(self, response, args):
pass
def main():
# Get proper serial class and port name
(serial, port) = get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(), port, reactor)
try:
from notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
if __name__ == '__main__':
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
if __name__ == '__main__':
main()
init_serial()
......@@ -3,28 +3,26 @@ import json
import logging
import platform
logger = logging.getLogger()
system = platform.system()
class SerialLineReceiverBase(LineReceiver, object):
MAX_LENGTH = 1024*1024*128
def __init__(self, *args, **kwargs):
if system == "FreeBSD":
self.delimiter = '\n'
if platform.system() == "FreeBSD":
self.delimiter = b'\n'
else:
self.delimiter = '\r'
self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args):
self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n')
self.transport.write((json.dumps({'response': response,
'args': args}) + '\r\n').encode())
def send_command(self, command, args):
self.transport.write(json.dumps({'command': command,
'args': args}) + '\r\n')
self.transport.write((json.dumps({'command': command,
'args': args}) + '\r\n').encode())
def handle_command(self, command, args):
raise NotImplementedError("Subclass must implement abstract method")
......@@ -45,12 +43,12 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.error('[serial] invalid json: %s (%s)' % (data, e))
return
if command is not None and isinstance(command, unicode):
if command is not None and isinstance(command, str):
logger.debug('received command: %s (%s)' % (command, args))
try:
self.handle_command(command, args)
except Exception as e:
logger.exception(u'Unhandled exception: ')
elif response is not None and isinstance(response, unicode):
elif response is not None and isinstance(response, str):
logger.debug('received reply: %s (%s)' % (response, args))
self.handle_response(response, args)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import environ
import platform
from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall
import uptime
import logging
from inspect import getfullargspec, isfunction
from agent.SerialLineReceiverBase import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from agent.context import BaseContext, get_context # noqa
Context = get_context()
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.transport.write(b'\r\n')
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': platform.system()})
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionLost(self, reason):
reactor.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def send_status(self):
import psutil
disk_usage = dict((disk.device.replace('/', '_'),
psutil.disk_usage(disk.mountpoint).percent)
for disk in psutil.disk_partitions())
args = {"cpu": dict(psutil.cpu_times()._asdict()),
"ram": dict(psutil.virtual_memory()._asdict()),
"swap": dict(psutil.swap_memory()._asdict()),
"uptime": {"seconds": uptime.uptime()},
"disk": disk_usage,
"user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args)
logger.debug("send_status finished")
def _check_args(self, func, args):
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
(self._pretty_fun(func), type(args).__name__))
# check for unexpected keyword arguments
argspec = getfullargspec(func)
if argspec.varkw is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
"Command %s got unexpected keyword arguments: %s" % (
self._pretty_fun(func), ", ".join(unexpected_kwargs)))
mandatory_args = argspec.args
if argspec.defaults: # remove those with default value
mandatory_args = mandatory_args[0:-len(argspec.defaults)]
missing_kwargs = set(mandatory_args) - set(args)
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
def _get_command(self, command, args):
if not isinstance(command, str) or command.startswith('_'):
raise AttributeError('Invalid command: %s' % command)
try:
func = getattr(Context, command)
except AttributeError as e:
raise AttributeError('Command not found: %s (%s)' % (command, e))
if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." %
self._pretty_fun(func))
self._check_args(func, args)
logger.debug("_get_command finished")
return func
@staticmethod
def _pretty_fun(fun):
try:
argspec = getfullargspec(fun)
args = argspec.args
if argspec.varargs:
args.append("*" + argspec.varargs)
if argspec.varkw:
args.append("**" + argspec.varkw)
return "%s(%s)" % (fun.__name__, ",".join(args))
except Exception:
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
func = self._get_command(command, args)
retval = func(**args)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
def handle_response(self, response, args):
pass
def init_serial():
# Get proper serial class and port name
(serial, port) = Context.get_serial()
logger.info("Opening port %s", port)
# Open serial connection
serial(SerialLineReceiver(), port, reactor)
try:
from agent.notify import register_publisher
register_publisher(reactor)
except Exception:
logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
......@@ -4,71 +4,19 @@
import platform
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
def get_context():
system = platform.system()
if system == "Windows":
from windows._win32context import Context
from agent.windows._win32context import Context
elif system == "Linux":
from linux._linuxcontext import Context
from agent.linux._linuxcontext import Context
elif system == "FreeBSD":
from freebsd._freebsdcontext import Context
from agent.freebsd._freebsdcontext import Context
else:
raise NotImplementedError("Platform %s is not supported.", system)
return Context
def get_serial():
system = platform.system()
port = None
if system == 'Windows':
port = _get_virtio_device()
import pythoncom
pythoncom.CoInitialize()
if port:
from windows.win32virtio import SerialPort
else:
from twisted.internet.serialport import SerialPort
port = r'\\.\COM1'
elif system == "Linux":
port = "/dev/virtio-ports/agent"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0'
else:
from linux.posixvirtio import SerialPort
elif system == "FreeBSD":
port = "/dev/ttyV0.1"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyu0'
else:
from freebsd.posixvirtio import SerialPort
else:
raise NotImplementedError("Platform %s is not supported.", system)
return (SerialPort, port)
class BaseContext(object):
@staticmethod
def change_password(password):
......@@ -136,5 +84,9 @@ class BaseContext(object):
@staticmethod
def send_expiration(url):
import notify
from agent import notify
notify.notify(url)
@staticmethod
def get_serial():
raise NotImplementedError()
......@@ -27,14 +27,14 @@ import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from StringIO import StringIO
from io import StringIO
from base64 import decodestring
from hashlib import md5
from ssh import PubKey
from .ssh import PubKey
from .network import change_ip_freebsd
from context import BaseContext
from agent.context import BaseContext
from twisted.internet import reactor
......@@ -69,7 +69,6 @@ class Context(BaseContext):
# python-module-to-change-system-date-and-time
@staticmethod
def _freebsd_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
......@@ -135,10 +134,10 @@ class Context(BaseContext):
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' smbfs ' in line):
print line.rstrip()
print(line.rstrip())
with open(NSMBRC, 'w') as f:
chmod(NSMBRC, 0600)
chmod(NSMBRC, 0o600)
f.write(nsmbrc_template_freebsd % data)
with open('/etc/fstab', 'a') as f:
......@@ -155,21 +154,21 @@ class Context(BaseContext):
try:
retval.append(PubKey.from_str(line))
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
print(keys)
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
f.write(str(key) + '\n')
@staticmethod
def add_keys(keys):
......@@ -180,7 +179,7 @@ class Context(BaseContext):
if p not in new_keys:
new_keys.append(p)
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
......@@ -194,7 +193,7 @@ class Context(BaseContext):
except ValueError:
pass
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
......@@ -280,5 +279,17 @@ class Context(BaseContext):
@staticmethod
def send_expiration(url):
import notify
from agent import notify
notify.notify(url)
@staticmethod
def get_serial():
port = "/dev/ttyV0.1"
try:
open(port, 'rw').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyu0'
else:
from agent.freebsd import SerialPort
return (SerialPort, port)
......@@ -37,7 +37,7 @@ def change_ip_freebsd(interfaces, nameservers):
data = list(get_interfaces_freebsd(interfaces))
for ifname, conf in data:
subprocess.call(('/usr/sbin/service', 'netif', 'stop', ifname))
remove_interfaces_freebsd(dict(data).keys())
remove_interfaces_freebsd(list(dict(data).keys()))
for device, conf in data:
if_file = rcconf_dir + "ifconfig_" + device
......
......@@ -37,10 +37,10 @@ class PubKey(object):
PubKey.validate_key(key_type, key)
self.key = key
self.comment = unicode(comment)
self.comment = str(comment)
def __hash__(self):
return hash(frozenset(self.__dict__.items()))
return hash(frozenset(list(self.__dict__.items())))
def __eq__(self, other):
return self.__dict__ == other.__dict__
......@@ -51,10 +51,10 @@ class PubKey(object):
return PubKey(key_type, key, comment)
def __unicode__(self):
return u' '.join((self.key_type, self.key, self.comment))
return ' '.join((self.key_type, self.key, self.comment))
def __repr__(self):
return u'<PubKey: %s>' % unicode(self)
return '<PubKey: %s>' % str(self)
# Unit tests
......@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase):
def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self):
self.assertEqual(self.p1, self.p2)
......
......@@ -17,15 +17,12 @@ import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from StringIO import StringIO
from io import StringIO
from base64 import decodestring
from hashlib import md5
from ssh import PubKey
from .network import change_ip_ubuntu, change_ip_rhel
from context import BaseContext
from agent.linux.ssh import PubKey
from agent.linux.network import change_ip_ubuntu, change_ip_rhel
from agent.context import BaseContext
from twisted.internet import reactor
......@@ -56,7 +53,6 @@ class Context(BaseContext):
# python-module-to-change-system-date-and-time
@staticmethod
def _linux_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
......@@ -77,7 +73,7 @@ class Context(BaseContext):
def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
proc.communicate(('cloud:%s\n' % password).encode())
@staticmethod
def restart_networking():
......@@ -123,7 +119,7 @@ class Context(BaseContext):
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
print(line.rstrip())
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
......@@ -139,7 +135,7 @@ class Context(BaseContext):
try:
retval.append(PubKey.from_str(line))
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
except IOError:
pass
return retval
......@@ -152,7 +148,7 @@ class Context(BaseContext):
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
f.write(str(key) + '\n')
uid = getpwnam("cloud").pw_uid
chown(SSH_DIR, uid, -1)
......@@ -167,7 +163,7 @@ class Context(BaseContext):
if p not in new_keys:
new_keys.append(p)
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
......@@ -181,7 +177,7 @@ class Context(BaseContext):
except ValueError:
pass
except Exception:
logger.exception(u'Invalid ssh key: ')
logger.exception('Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
......@@ -263,5 +259,17 @@ class Context(BaseContext):
@staticmethod
def send_expiration(url):
import notify
from agent import notify
notify.notify(url)
@staticmethod
def get_serial():
port = "/dev/virtio-ports/agent"
try:
open(port, 'w').close()
except (OSError, IOError):
from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0'
else:
from agent.linux.posixvirtio import SerialPort
return (SerialPort, port)
......@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices):
if line.startswith('#') or line == '' or line.isspace() or not words:
# keep line
print line
print(line)
continue
if (words[0] in ('auto', 'allow-hotplug') and
......@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices):
continue
# keep line
print line
print(line)
def change_ip_ubuntu(interfaces, nameservers):
......@@ -60,7 +60,7 @@ def change_ip_ubuntu(interfaces, nameservers):
subprocess.call(('/sbin/ip', 'addr', 'flush', 'dev', ifname))
subprocess.call(('/sbin/ip', 'link', 'set', 'dev', ifname,
'down'))
remove_interfaces_ubuntu(dict(data).keys())
remove_interfaces_ubuntu(list(dict(data).keys()))
with open(interfaces_file, 'a') as f:
for ifname, conf in data:
......
......@@ -37,10 +37,10 @@ class PubKey(object):
PubKey.validate_key(key_type, key)
self.key = key
self.comment = unicode(comment)
self.comment = str(comment)
def __hash__(self):
return hash(frozenset(self.__dict__.items()))
return hash(frozenset(list(self.__dict__.items())))
def __eq__(self, other):
return self.__dict__ == other.__dict__
......@@ -51,10 +51,10 @@ class PubKey(object):
return PubKey(key_type, key, comment)
def __unicode__(self):
return u' '.join((self.key_type, self.key, self.comment))
return ' '.join((self.key_type, self.key, self.comment))
def __repr__(self):
return u'<PubKey: %s>' % unicode(self)
return '<PubKey: %s>' % str(self)
# Unit tests
......@@ -85,11 +85,11 @@ class SshTestCase(unittest.TestCase):
def test_unicode(self):
p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_from_str(self):
p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
self.assertEqual(str(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')
def test_eq(self):
self.assertEqual(self.p1, self.p2)
......
......@@ -5,7 +5,7 @@
# Notify user about vm expiring
##
import cookielib
import http.cookiejar
import errno
import json
import logging
......@@ -13,8 +13,8 @@ import multiprocessing
import os
import platform
import subprocess
import urllib2
from urlparse import urlsplit
import urllib.request, urllib.error, urllib.parse
from urllib.parse import urlsplit
logger = logging.getLogger()
logger.debug("notify imported")
......@@ -58,19 +58,19 @@ def accept():
from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path):
print "There is no recent notification to accept."
print("There is no recent notification to accept.")
return False
# Load the saved url
url = json.load(open(file_path, "r"))
cj = cookielib.CookieJar()
opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj))
cj = http.cookiejar.CookieJar()
opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(cj))
try:
opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url))
cookies = cj._cookies_for_request(urllib.request.Request(url))
token = [c for c in cookies if c.name == "csrftoken"][0].value
req = urllib2.Request(url, "", {
req = urllib.request.Request(url, "", {
"accept": "application/json", "referer": url,
"x-csrftoken": token})
rsp = opener.open(req)
......@@ -83,10 +83,10 @@ def accept():
new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e:
print "Parsing time failed: %s" % e
print("Parsing time failed: %s" % e)
except Exception as e:
print e
print "Renewal failed. Please try it manually at %s" % url
print(e)
print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed")
return False
else:
......@@ -102,7 +102,7 @@ def notify(url):
if win:
logger.info("notifying %d clients", len(clients))
for c in clients:
logger.debug("sending url %s to client %s", url, unicode(c))
logger.debug("sending url %s to client %s", url, str(c))
c.sendLine(url.encode())
else:
file_path = os.path.join(get_temp_dir(), file_name)
......@@ -209,11 +209,11 @@ if win:
self.factory = factory
def connectionMade(self):
logger.info("client connected: %s", unicode(self))
logger.info("client connected: %s", str(self))
clients.add(self)
def connectionLost(self, reason):
logger.info("client disconnected: %s", unicode(self))
logger.info("client disconnected: %s", str(self))
clients.remove(self)
class PubFactory(protocol.Factory):
......@@ -230,7 +230,7 @@ if win:
class SubProtocol(basic.LineReceiver):
def lineReceived(self, line):
print "received", line
print("received", line)
if line.startswith('cifs://'):
mount_smb(line)
else:
......@@ -243,7 +243,7 @@ if win:
def run_client():
from twisted.internet import reactor
print "connect to localhost:%d" % port
print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory())
reactor.run()
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
working_directory = r"C:\circle" # noqa
from os.path import join
import logging
import tarfile
from StringIO import StringIO
from base64 import decodestring
from tarfile import TarFile, ReadError
from io import StringIO
from base64 import b64decode
from hashlib import md5
from datetime import datetime
import win32api
import wmi
import netifaces
from twisted.internet import reactor
from .network import change_ip_windows
from context import BaseContext
from agent.windows.network import change_ip_windows
from agent.context import BaseContext
logger = logging.getLogger()
working_directory = r"C:\circle" # noqa
class Context(BaseContext):
......@@ -54,10 +48,10 @@ class Context(BaseContext):
@staticmethod
def mount_store(host, username, password):
import notify
from agent import notify
url = 'cifs://%s:%s@%s/%s' % (username, password, host, username)
for c in notify.clients:
logger.debug("sending url %s to client %s", url, unicode(c))
logger.debug("sending url %s to client %s", url, str(c))
c.sendLine(url.encode())
@staticmethod
......@@ -94,7 +88,7 @@ class Context(BaseContext):
@staticmethod
def _update_registry(dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from _winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
from winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
......@@ -111,11 +105,11 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
decoded = StringIO(b64decode(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar = TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory)
except tarfile.ReadError as e:
except ReadError as e:
logger.error(e)
logger.info("Transfer completed!")
Context._update_registry(working_directory, executable)
......@@ -145,3 +139,32 @@ class Context(BaseContext):
return f.readline()
except IOError:
return None
@staticmethod
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
@staticmethod
def get_serial():
port = Context._get_virtio_device()
import pythoncom
pythoncom.CoInitialize()
if port:
from agent.windows.win32virtio import SerialPort
else:
from twisted.internet.serialport import SerialPort
port = r'\\.\COM1'
return (SerialPort, port)
......@@ -81,9 +81,9 @@ class SerialPort(abstract.FileDescriptor):
except Exception:
import time
time.sleep(10)
n = 0
n = None
if n:
first = str(self.read_buf[:n])
first = bytes(self.read_buf[:n])
# now we should get everything that is already in the buffer (max
# 4096)
win32event.ResetEvent(self._overlappedRead.hEvent)
......@@ -95,7 +95,7 @@ class SerialPort(abstract.FileDescriptor):
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
self.protocol.dataReceived(first + bytes(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort,
......
......@@ -11,7 +11,7 @@ fh.setFormatter(formatter)
logger.addHandler(fh)
from notify import run_client
from agent.notify import run_client
if __name__ == '__main__':
run_client()
Twisted==13.2.0
pyserial==2.7
psutil==5.4.8
uptime==3.0.1
netifaces==0.10.4
netaddr==0.7.12
Twisted
pyserial
psutil
uptime
netifaces
netaddr
infi.devicemanager
tzlocal
pytz
pywin32 ; sys_platform == 'win32'
wmi ; sys_platform == 'win32'
#!/usr/bin/env python
import notify
from agent import notify
if __name__ == '__main__':
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment