Commit ecbb709b by Szeberényi Imre

Python3 version

parent 6bdaacab
...@@ -8,6 +8,7 @@ import win32event ...@@ -8,6 +8,7 @@ import win32event
import win32service import win32service
import win32serviceutil import win32serviceutil
#import agent
from agent import main as agent_main, reactor from agent import main as agent_main, reactor
logger = logging.getLogger() logger = logging.getLogger()
...@@ -17,7 +18,7 @@ formatter = logging.Formatter( ...@@ -17,7 +18,7 @@ formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s") "%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO') level = os.environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level) logger.setLevel(level)
logger.info("%s loaded", __file__) logger.info("%s loaded", __file__)
...@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main(): def main():
if len(sys.argv) == 1: if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting... # service must be starting...
# for the sake of debugging etc, we use win32traceutil to see # for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements. # any unhandled exceptions and print statements.
......
...@@ -12,68 +12,106 @@ system = platform.system() # noqa ...@@ -12,68 +12,106 @@ system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa try: # noqa
chdir(sys.path[0]) # noqa chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa subprocess.call(('pip', 'install', '-r', 'requirements/linux.txt')) # noqa
except Exception: # noqa except Exception: # noqa
pass # hope it works # noqa pass # hope it works # noqa
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
import uptime import uptime
import logging import logging
from inspect import getargspec, isfunction from inspect import getargs, isfunction
from utils import SerialLineReceiverBase from utils import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext # Note: Import everything because later we need to use the BaseContext
# (relative import error. # (relative import error.
from context import BaseContext, get_context, get_serial # noqa from context import BaseContext, get_context, get_serial # noqa
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
try:
from inspect import getfullargspec as getargspec
except ImportError:
from inspect import getargspec as getargspec
#############################################################
Context = get_context() Context = get_context()
logging.basicConfig() logging.basicConfig(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
logger = logging.getLogger() logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO') level = environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level) logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase): class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self): def __init__(self):
self.transport.write('\r\n') super(SerialLineReceiver, self).__init__()
self.send_command( self.tickId = LoopingCall(self.tick)
command='agent_started', self.mayStartNowId = LoopingCall(self.mayStartNow)
args={'version': Context.get_agent_version(), reactor.addSystemEventTrigger("before", "shutdown", self.shutdown)
'system': system}) self.running = True
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionMade(self):
logger.debug("connectionMade")
self.clearLineBuffer()
self.tickId.start(5, now=False)
self.mayStartNowId.start(10, now=False)
self.send_startMsg()
def connectionLost(self, reason): def connectionLost(self, reason):
reactor.stop() logger.debug("connectionLost")
if self.tickId.running:
self.tickId.stop()
if self.mayStartNowId.running:
self.mayStartNowId.stop()
def connectionLost2(self, reason): def connectionLost2(self, reason):
self.send_command(command='agent_stopped', self.send_command(command='agent_stopped', args={})
args={})
def mayStartNow(self):
if BaseContext.placed:
self.mayStartNowId.stop()
logger.info("Placed")
return
self.send_startMsg()
def tick(self): def tick(self):
logger.debug("Sending tick") logger.debug("Sending tick")
try: try:
self.send_status() self.send_status()
except Exception: except Exception as e:
logger.exception("Twisted hide exception") logger.debug("Exception durig tick: %s" % e)
# logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__() def shutdown(self):
self.lc = LoopingCall(self.tick) self.running = False
self.lc.start(5, now=False) logger.debug("shutdown")
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
def send_startMsg(self):
logger.debug("Sending start message...")
# Hack for flushing the lower level buffersr
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
def send_status(self): def send_status(self):
import psutil import psutil
...@@ -90,6 +128,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -90,6 +128,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
logger.debug("send_status finished") logger.debug("send_status finished")
def _check_args(self, func, args): def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict): if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a " raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." % "dict for command %s instead of %s." %
...@@ -97,7 +136,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -97,7 +136,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
# check for unexpected keyword arguments # check for unexpected keyword arguments
argspec = getargspec(func) argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args try:
_kwargs = argspec.keywords
except AttributeError:
_kwargs = argspec.varkw
if _kwargs is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args) unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs: if unexpected_kwargs:
raise TypeError( raise TypeError(
...@@ -111,9 +154,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -111,9 +154,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs: if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % ( raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs))) self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished")
def _get_command(self, command, args): def _get_command(self, command, args):
if not isinstance(command, basestring) or command.startswith('_'): logger.debug("_get_command %s %s" % (command, args))
if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command) raise AttributeError(u'Invalid command: %s' % command)
try: try:
func = getattr(Context, command) func = getattr(Context, command)
...@@ -142,8 +187,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -142,8 +187,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
return "<%s>" % type(fun).__name__ return "<%s>" % type(fun).__name__
def handle_command(self, command, args): def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args))
func = self._get_command(command, args) func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args))
retval = func(**args) retval = func(**args)
logger.debug("Retval: %s" % retval)
self.send_response( self.send_response(
response=func.__name__, response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)}) args={'retval': retval, 'uuid': args.get('uuid', None)})
...@@ -151,7 +199,6 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -151,7 +199,6 @@ class SerialLineReceiver(SerialLineReceiverBase):
def handle_response(self, response, args): def handle_response(self, response, args):
pass pass
def main(): def main():
# Get proper serial class and port name # Get proper serial class and port name
(serial, port) = get_serial() (serial, port) = get_serial()
......
""" This is the defautl context file. It replaces the Context class """ This is the defautl context file. It replaces the Context class
to the platform specific one. to the platform specific one.
""" """
import logging
import platform import platform
logger = logging.getLogger()
def _get_virtio_device(): def _get_virtio_device():
path = None path = None
...@@ -18,6 +20,8 @@ def _get_virtio_device(): ...@@ -18,6 +20,8 @@ def _get_virtio_device():
i.children[0].instance_id.lower().replace('\\', '#') + i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower() "#" + GUID.lower()
) )
break
logger.debug("DEV found: %s", path)
return path return path
...@@ -36,6 +40,7 @@ def get_context(): ...@@ -36,6 +40,7 @@ def get_context():
def get_serial(): def get_serial():
system = platform.system() system = platform.system()
logger.debug("Get_serial system: %s", system)
port = None port = None
if system == 'Windows': if system == 'Windows':
port = _get_virtio_device() port = _get_virtio_device()
...@@ -49,8 +54,10 @@ def get_serial(): ...@@ -49,8 +54,10 @@ def get_serial():
elif system == "Linux": elif system == "Linux":
port = "/dev/virtio-ports/agent" port = "/dev/virtio-ports/agent"
try: try:
open(port, 'rw').close() print("Open!")
except (OSError, IOError): open(port, 'r').close()
except (OSError, IOError) as e:
print(e)
from twisted.internet.serialport import SerialPort from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0' port = '/dev/ttyS0'
else: else:
...@@ -70,6 +77,8 @@ def get_serial(): ...@@ -70,6 +77,8 @@ def get_serial():
class BaseContext(object): class BaseContext(object):
placed = False # if we reciwed password or net commands
@staticmethod @staticmethod
def change_password(password): def change_password(password):
pass pass
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
from os import mkdir, remove, chown from os import mkdir, remove, chown
from pwd import getpwnam from pwd import getpwnam
import platform try:
from distro import distro as platform
except ImportError:
import platform
from shutil import rmtree, move from shutil import rmtree, move
import subprocess import subprocess
import sys import sys
...@@ -17,12 +21,21 @@ import fileinput ...@@ -17,12 +21,21 @@ import fileinput
import tarfile import tarfile
from os.path import expanduser, join, exists from os.path import expanduser, join, exists
from glob import glob from glob import glob
from StringIO import StringIO
from base64 import decodestring try:
from StringIO import StringIO
except ImportError:
from io import StringIO
try:
from base64 import decodestring
except ImportError:
from base64 import decodebytes as decodestring
from hashlib import md5 from hashlib import md5
from ssh import PubKey from .ssh import PubKey
from .network import change_ip_ubuntu, change_ip_rhel from .network import change_ip_ubuntu, change_ip_rhel
from context import BaseContext from context import BaseContext
...@@ -77,7 +90,8 @@ class Context(BaseContext): ...@@ -77,7 +90,8 @@ class Context(BaseContext):
def change_password(password): def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'], proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE) stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password) proc.communicate(str.encode('cloud:%s\n' % password))
BaseContext.placed = True
@staticmethod @staticmethod
def restart_networking(): def restart_networking():
...@@ -90,10 +104,12 @@ class Context(BaseContext): ...@@ -90,10 +104,12 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
nameservers = dns.replace(' ', '').split(',') nameservers = dns.replace(' ', '').split(',')
logger.debug("Distro: %s" % distro)
if distro == 'debian': if distro == 'debian':
change_ip_ubuntu(interfaces, nameservers) change_ip_ubuntu(interfaces, nameservers)
elif distro == 'rhel': elif distro == 'rhel':
change_ip_rhel(interfaces, nameservers) change_ip_rhel(interfaces, nameservers)
BaseContext.placed = True
@staticmethod @staticmethod
def set_time(time): def set_time(time):
...@@ -123,7 +139,7 @@ class Context(BaseContext): ...@@ -123,7 +139,7 @@ class Context(BaseContext):
# TODO # TODO
for line in fileinput.input('/etc/fstab', inplace=True): for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line): if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip() print(line.rstrip())
with open('/etc/fstab', 'a') as f: with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data) f.write(mount_template_linux % data)
......
...@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices): ...@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices):
if line.startswith('#') or line == '' or line.isspace() or not words: if line.startswith('#') or line == '' or line.isspace() or not words:
# keep line # keep line
print line print(line)
continue continue
if (words[0] in ('auto', 'allow-hotplug') and if (words[0] in ('auto', 'allow-hotplug') and
...@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices): ...@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices):
continue continue
# keep line # keep line
print line print(line)
def change_ip_ubuntu(interfaces, nameservers): def change_ip_ubuntu(interfaces, nameservers):
......
...@@ -34,9 +34,16 @@ class SerialPort(abstract.FileDescriptor): ...@@ -34,9 +34,16 @@ class SerialPort(abstract.FileDescriptor):
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK) self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.reactor = reactor self.reactor = reactor
self.protocol = protocol self.protocol = protocol
self.connected = 1
self.protocol.makeConnection(self) self.protocol.makeConnection(self)
self.startReading() self.startReading()
def write(self, data):
if data:
if isinstance(data, str):
data = str.encode(data)
self.writeSomeData(data)
def fileno(self): def fileno(self):
return self._serial return self._serial
...@@ -60,9 +67,12 @@ class SerialPort(abstract.FileDescriptor): ...@@ -60,9 +67,12 @@ class SerialPort(abstract.FileDescriptor):
serial data. serial data.
""" """
abstract.FileDescriptor.connectionLost(self, reason) abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
os.close(self._serial) os.close(self._serial)
sleep(2) logger.debug("Reconecting after 5s")
sleep(5)
self._serial = os.open( self._serial = os.open(
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK) self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.connected = 1
self.protocol.makeConnection(self)
self.startReading() self.startReading()
logger.info("Reconnecting")
from base64 import decodestring
try:
from base64 import decodestring
except ImportError:
from base64 import decodebytes as decodestring
from struct import unpack from struct import unpack
import binascii import binascii
import unittest import unittest
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# Notify user about vm expiring # Notify user about vm expiring
## ##
import cookielib
import errno import errno
import json import json
import logging import logging
...@@ -13,8 +12,27 @@ import multiprocessing ...@@ -13,8 +12,27 @@ import multiprocessing
import os import os
import platform import platform
import subprocess import subprocess
import urllib2
from urlparse import urlsplit try:
import cookielib
except ImportError:
import http.cookiejar as cookielib
try:
import urllib2
except ImportError:
import urllib.request as urllib2
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
logger.debug("notify imported") logger.debug("notify imported")
...@@ -43,6 +61,8 @@ def get_temp_dir(): ...@@ -43,6 +61,8 @@ def get_temp_dir():
def wall(text): def wall(text):
if isinstance(text, str):
text = str.encode(text)
if win: if win:
return return
if text is None: if text is None:
...@@ -58,7 +78,7 @@ def accept(): ...@@ -58,7 +78,7 @@ def accept():
from pytz import UTC from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name) file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
print "There is no recent notification to accept." print("There is no recent notification to accept.")
return False return False
# Load the saved url # Load the saved url
...@@ -70,9 +90,9 @@ def accept(): ...@@ -70,9 +90,9 @@ def accept():
opener.open(url) # GET to collect cookies opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url)) cookies = cj._cookies_for_request(urllib2.Request(url))
token = [c for c in cookies if c.name == "csrftoken"][0].value token = [c for c in cookies if c.name == "csrftoken"][0].value
req = urllib2.Request(url, "", { req = urllib2.Request(url, b"", {
"accept": "application/json", "referer": url, b"accept": b"application/json", b"referer": url,
"x-csrftoken": token}) b"x-csrftoken": token})
rsp = opener.open(req) rsp = opener.open(req)
data = json.load(rsp) data = json.load(rsp)
newtime = data["new_suspend_time"] newtime = data["new_suspend_time"]
...@@ -83,10 +103,10 @@ def accept(): ...@@ -83,10 +103,10 @@ def accept():
new_local_time = parsed_time.astimezone( new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S") get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e: except ValueError as e:
print "Parsing time failed: %s" % e print("Parsing time failed: %s" % e)
except Exception as e: except Exception as e:
print e print(e)
print "Renewal failed. Please try it manually at %s" % url print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed") logger.exception("renew failed")
return False return False
else: else:
...@@ -144,6 +164,7 @@ def open_in_browser(url): ...@@ -144,6 +164,7 @@ def open_in_browser(url):
def mount_smb(url): def mount_smb(url):
data = urlsplit(url) data = urlsplit(url)
share = data.path.lstrip('/') share = data.path.lstrip('/')
print("host: %s share %s user: %s pw: %s" % (data.hostname, share, data.username, data.password))
subprocess.call(('net', 'use', 'Z:', '/delete')) subprocess.call(('net', 'use', 'Z:', '/delete'))
try: try:
p = subprocess.Popen(( p = subprocess.Popen((
...@@ -228,9 +249,20 @@ if win: ...@@ -228,9 +249,20 @@ if win:
reactor.listenTCP(port, PubFactory(), interface='localhost') reactor.listenTCP(port, PubFactory(), interface='localhost')
class SubProtocol(basic.LineReceiver): class SubProtocol(basic.LineReceiver):
def connectionMade(self):
logger.info("Subclient connected: %s", unicode(self))
clients.add(self)
def connectionLost(self, reason):
logger.info("Subclient disconnected: %s", unicode(self))
clients.remove(self)
def lineReceived(self, line): def lineReceived(self, line):
print "received", line logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str):
line = line.decode()
if line.startswith('cifs://'): if line.startswith('cifs://'):
mount_smb(line) mount_smb(line)
else: else:
...@@ -243,7 +275,7 @@ if win: ...@@ -243,7 +275,7 @@ if win:
def run_client(): def run_client():
from twisted.internet import reactor from twisted.internet import reactor
print "connect to localhost:%d" % port print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory()) reactor.connectTCP("localhost", port, SubFactory())
reactor.run() reactor.run()
......
Twisted
pyserial
psutil
uptime
netifaces
netaddr
infi.devicemanager
tzlocal
pytz
distro
Twisted
pyserial
psutil
uptime
netifaces
netaddr
infi.devicemanager
tzlocal
pytz
pywin32
wmi
...@@ -3,6 +3,11 @@ import json ...@@ -3,6 +3,11 @@ import json
import logging import logging
import platform import platform
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
system = platform.system() system = platform.system()
...@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if system == "FreeBSD": if system == "FreeBSD":
self.delimiter = '\n' self.delimiter = b'\n'
else: else:
self.delimiter = '\r' self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs) super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args): def send_response(self, response, args):
# logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response, self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n') 'args': args}) + '\r\n')
...@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method") raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data): def lineReceived(self, data):
logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)):
data = data.strip('\0')
else:
data = data.strip(b'\0')
try: try:
data = json.loads(data) data = json.loads(data)
args = data.get('args', {}) args = data.get('args', {})
...@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.debug('[serial] valid json: %s' % (data, )) logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e: except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e)) logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer()
return return
if command is not None and isinstance(command, unicode): if command is not None and isinstance(command, unicode):
...@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object):
try: try:
self.handle_command(command, args) self.handle_command(command, args)
except Exception as e: except Exception as e:
logger.exception(u'Unhandled exception: ') logger.exception("Unhandled exception during line recived: ")
elif response is not None and isinstance(response, unicode): elif response is not None and isinstance(response, unicode):
logger.debug('received reply: %s (%s)' % (response, args)) logger.debug('received reply: %s (%s)' % (response, args))
self.clearLineBuffer()
self.handle_response(response, args) self.handle_response(response, args)
...@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main(): def main():
if len(sys.argv) == 1: if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting... # service must be starting...
# for the sake of debugging etc, we use win32traceutil to see # for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements. # any unhandled exceptions and print statements.
......
...@@ -7,8 +7,8 @@ from os.path import join ...@@ -7,8 +7,8 @@ from os.path import join
import logging import logging
import tarfile import tarfile
from StringIO import StringIO from io import StringIO
from base64 import decodestring from base64 import b64decode
from hashlib import md5 from hashlib import md5
from datetime import datetime from datetime import datetime
import win32api import win32api
...@@ -20,7 +20,12 @@ from twisted.internet import reactor ...@@ -20,7 +20,12 @@ from twisted.internet import reactor
from .network import change_ip_windows from .network import change_ip_windows
from context import BaseContext from context import BaseContext
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
...@@ -28,6 +33,7 @@ class Context(BaseContext): ...@@ -28,6 +33,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_password(password): def change_password(password):
BaseContext.placed = True
from win32com import adsi from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud') ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo() ads_obj.Getinfo()
...@@ -39,6 +45,7 @@ class Context(BaseContext): ...@@ -39,6 +45,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
BaseContext.placed = True
nameservers = dns.replace(' ', '').split(',') nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers) change_ip_windows(interfaces, nameservers)
...@@ -111,7 +118,7 @@ class Context(BaseContext): ...@@ -111,7 +118,7 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest() local_checksum = md5(data).hexdigest()
if local_checksum != checksum: if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.") raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data)) decoded = StringIO(b64decode(data))
try: try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz') tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory) tar.extractall(working_directory)
......
...@@ -7,12 +7,14 @@ Serial port support for Windows. ...@@ -7,12 +7,14 @@ Serial port support for Windows.
Requires PySerial and pywin32. Requires PySerial and pywin32.
""" """
# system imports
import win32file import win32file
import win32event import win32event
import win32con import win32con
# system imports
from serial.serialutil import to_bytes # type: ignore[import]
from time import sleep
# twisted imports # twisted imports
from twisted.internet import abstract from twisted.internet import abstract
...@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor): ...@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor):
connected = 1 connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor): def __init__(self, protocol, deviceName, reactor):
self.initHard(protocol, deviceName, reactor)
self.initSoft(protocol, deviceName, reactor)
def initHard(self, protocol, deviceName, reactor):
self.hComPort = win32file.CreateFile( self.hComPort = win32file.CreateFile(
deviceNameOrPortNumber, deviceName,
win32con.GENERIC_READ | win32con.GENERIC_WRITE, win32con.GENERIC_READ | win32con.GENERIC_WRITE,
0, # exclusive access 0, # exclusive access
None, # no security None, # no security
...@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor): ...@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor):
0) 0)
self.reactor = reactor self.reactor = reactor
self.protocol = protocol self.protocol = protocol
self.outQueue = [] self.deviceName = deviceName
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.protocol = protocol
self._overlappedRead = win32file.OVERLAPPED() self._overlappedRead = win32file.OVERLAPPED()
self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None) self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None)
self._overlappedWrite = win32file.OVERLAPPED() self._overlappedWrite = win32file.OVERLAPPED()
self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None) self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None)
self.reactor.addEvent(self._overlappedRead.hEvent, self, 'serialReadEvent')
self.reactor.addEvent(self._overlappedWrite.hEvent, self, 'serialWriteEvent')
self.reactor.addEvent( def initSoft(self, protocol, deviceName, reactor):
self._overlappedRead.hEvent, self.outQueue = []
self, self.closed = 0
'serialReadEvent') self.closedNotifies = 0
self.reactor.addEvent( self.writeInProgress = 0
self._overlappedWrite.hEvent, self.conneted = 1
self, self._reconnInProgress = False
'serialWriteEvent')
self.protocol.makeConnection(self) self.protocol.makeConnection(self)
self._finishPortSetup() self._startReading()
def _finishPortSetup(self):
"""
Finish setting up the serial port.
This is a separate method to facilitate testing. def _startReading(self, len=4096):
"""
rc, self.read_buf = win32file.ReadFile(self.hComPort, rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1), win32file.AllocateReadBuffer(len),
self._overlappedRead) self._overlappedRead)
def serialReadEvent(self): def serialReadEvent(self):
# get that character we set up logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try: try:
n = win32file.GetOverlappedResult( n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
self.hComPort, except Exception as e:
self._overlappedRead, logger.debug("Exception %s" % e)
0) sleep(10)
except Exception: logger.debug(self.connLost())
import time
time.sleep(10)
n = 0 n = 0
if n: if n > 0:
first = str(self.read_buf[:n]) # handle the received data:
# now we should get everything that is already in the buffer (max self.protocol.dataReceived(to_bytes(self.read_buf[:n]))
# 4096) # set up next read
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(4096),
self._overlappedRead)
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent) win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort, self._startReading()
win32file.AllocateReadBuffer(1),
self._overlappedRead)
def write(self, data): def write(self, data):
if data: if data:
if isinstance(data, str):
data = str.encode(data)
if self.writeInProgress: if self.writeInProgress:
self.outQueue.append(data) self.outQueue.append(data)
logger.debug("added to queue") logger.debug("added to queue")
else: else:
self.writeInProgress = 1 self.writeInProgress = 1
win32file.WriteFile(self.hComPort, data, self._overlappedWrite) ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file") logger.debug("Writed to file %s", ret)
def serialWriteEvent(self): def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost())
self.writeInProgress = 0
return
try: try:
dataToWrite = self.outQueue.pop(0) dataToWrite = self.outQueue.pop(0)
except IndexError: except IndexError:
self.writeInProgress = 0 self.writeInProgress = 0
return return
else: else:
win32file.WriteFile( win32file.WriteFile(self.hComPort, dataToWrite, self._overlappedWrite)
self.hComPort,
dataToWrite, def connLost(self):
self._overlappedWrite) if self._reconnInProgress:
return None
self._reconnInProgress = True
return self.reactor.callLater(30, self.connectionLostEvent, self)
def connectionLostEvent(self, reason):
abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
logger.debug("Reconecting after 30s")
# sleep(30)
self.initSoft(self.protocol, self.deviceName, self.reactor)
def connectionLost(self, reason): def connectionLost(self, reason=None):
""" """
Called when the serial port disconnects. Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the Will call C{connectionLost} on the protocol that is handling the
serial data. serial data.
""" """
# import pdb; pdb.set_trace()
win32file.CancelIo(self.hComPort)
self.reactor.removeEvent(self._overlappedRead.hEvent) self.reactor.removeEvent(self._overlappedRead.hEvent)
self.reactor.removeEvent(self._overlappedWrite.hEvent) self.reactor.removeEvent(self._overlappedWrite.hEvent)
win32file.CloseHandle(self._overlappedRead.hEvent)
win32file.CloseHandle(self._overlappedWrite.hEvent)
abstract.FileDescriptor.connectionLost(self, reason) abstract.FileDescriptor.connectionLost(self, reason)
win32file.CloseHandle(self.hComPort) win32file.CloseHandle(self.hComPort)
self.protocol.connectionLost(reason) self.protocol.connectionLost(reason)
logger.debug("Hard reconecting after 10s")
sleep(10)
self.initHard(self.protocol, self.deviceName, self.reactor)
self.initSoft(self.protocol, self.deviceName, self.reactor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment