Commit e69aa084 by Szeberényi Imre

fix: 2<->3, win-virtio, new uppdate scheme

parent ba6ba173
......@@ -8,6 +8,7 @@ import win32event
import win32service
import win32serviceutil
#import agent
from agent import main as agent_main, reactor
logger = logging.getLogger()
......@@ -17,7 +18,7 @@ formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO')
level = os.environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level)
logger.info("%s loaded", __file__)
......@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main():
if len(sys.argv) == 1:
if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements.
......
......@@ -6,24 +6,13 @@ import platform
import subprocess
import sys
system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall
import uptime
import logging
from inspect import getargspec, isfunction
from inspect import getargs, isfunction
from utils import SerialLineReceiverBase
......@@ -31,9 +20,37 @@ from utils import SerialLineReceiverBase
# (relative import error.
from context import BaseContext, get_context, get_serial # noqa
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
try:
from inspect import getfullargspec as getargspec
except ImportError:
from inspect import getargspec as getargspec
def foo(a, *args, **kwargs):
pass
system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
Context = get_context()
logging.basicConfig()
logging.basicConfig(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
......@@ -45,35 +62,42 @@ class SerialLineReceiver(SerialLineReceiverBase):
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.tickId = LoopingCall(self.tick)
self.mayStartNowId = LoopingCall(self.mayStartNow)
reactor.addSystemEventTrigger("before", "shutdown", self.shutdown)
self.running = True
def connectionMade(self):
logger.debug("connectionMade")
self.clearLineBuffer()
self.tickId.start(5, now=False)
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
self.mayStartNowId.start(10, now=False)
self.send_startMsg()
def connectionLost(self, reason):
logger.debug("connectionLost")
if self.tickId.running:
self.tickId.stop()
if self.mayStartNowId.running:
self.mayStartNowId.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def mayStartNow(self):
if BaseContext.placed:
self.mayStartNowId.stop()
logger.info("Placed")
return
self.send_startMsg()
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
except Exception as e:
logger.debug("Exception durig tick: %s" % e)
# logger.exception("Twisted hide exception")
def shutdown(self):
self.running = False
......@@ -83,6 +107,18 @@ class SerialLineReceiver(SerialLineReceiverBase):
reactor.callLater(0.3, d.callback, "1")
return d
def send_startMsg(self):
logger.debug("Sending start message...")
# Hack for flushing the lower level buffersr
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
def send_status(self):
import psutil
disk_usage = dict((disk.device.replace('/', '_'),
......@@ -98,6 +134,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
logger.debug("send_status finished")
def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
......@@ -105,7 +142,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
# check for unexpected keyword arguments
argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args
try:
_kwargs = argspec.keywords
except AttributeError:
_kwargs = argspec.varkw
if _kwargs is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
......@@ -119,10 +160,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished")
def _get_command(self, command, args):
logger.debug("_get_command %s %s" % (command, args))
if not isinstance(command, basestring) or command.startswith('_'):
if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
func = getattr(Context, command)
......@@ -153,7 +195,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args))
func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args))
retval = func(**args)
logger.debug("Retval: %s" % retval)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
......
""" This is the defautl context file. It replaces the Context class
to the platform specific one.
"""
import logging
import platform
logger = logging.getLogger()
def _get_virtio_device():
path = None
......@@ -18,6 +20,8 @@ def _get_virtio_device():
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
break
logger.debug("DEV found: %s", path)
return path
......@@ -36,6 +40,7 @@ def get_context():
def get_serial():
system = platform.system()
logger.debug("Get_serial system: %s", system)
port = None
if system == 'Windows':
port = _get_virtio_device()
......@@ -70,6 +75,8 @@ def get_serial():
class BaseContext(object):
placed = False # if we reciwed password or net commands
@staticmethod
def change_password(password):
pass
......
......@@ -5,7 +5,6 @@
# Notify user about vm expiring
##
import cookielib
import errno
import json
import logging
......@@ -13,8 +12,27 @@ import multiprocessing
import os
import platform
import subprocess
import urllib2
from urlparse import urlsplit
try:
import cookielib
except ImportError:
import http.cookiejar as cookielib
try:
import urllib2
except ImportError:
import urllib.request as urllib2
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
logger.debug("notify imported")
......@@ -58,7 +76,7 @@ def accept():
from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path):
print "There is no recent notification to accept."
print("There is no recent notification to accept.")
return False
# Load the saved url
......@@ -83,10 +101,10 @@ def accept():
new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e:
print "Parsing time failed: %s" % e
print("Parsing time failed: %s" % e)
except Exception as e:
print e
print "Renewal failed. Please try it manually at %s" % url
print(e)
print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed")
return False
else:
......@@ -144,6 +162,7 @@ def open_in_browser(url):
def mount_smb(url):
data = urlsplit(url)
share = data.path.lstrip('/')
print("host: %s share %s user: %s pw: %s" % (data.hostname, share, data.username, data.password))
subprocess.call(('net', 'use', 'Z:', '/delete'))
try:
p = subprocess.Popen((
......@@ -229,8 +248,19 @@ if win:
class SubProtocol(basic.LineReceiver):
def connectionMade(self):
logger.info("Subclient connected: %s", unicode(self))
clients.add(self)
def connectionLost(self, reason):
logger.info("Subclient disconnected: %s", unicode(self))
clients.remove(self)
def lineReceived(self, line):
print "received", line
logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str):
line = line.decode()
if line.startswith('cifs://'):
mount_smb(line)
else:
......@@ -243,7 +273,7 @@ if win:
def run_client():
from twisted.internet import reactor
print "connect to localhost:%d" % port
print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory())
reactor.run()
......
......@@ -3,6 +3,11 @@ import json
import logging
import platform
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
system = platform.system()
......@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object):
def __init__(self, *args, **kwargs):
if system == "FreeBSD":
self.delimiter = '\n'
self.delimiter = b'\n'
else:
self.delimiter = '\r'
self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args):
# logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n')
......@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data):
logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)):
data = data.strip('\0')
else:
data = data.strip(b'\0')
try:
data = json.loads(data)
args = data.get('args', {})
......@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer()
return
if command is not None and isinstance(command, unicode):
......@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object):
try:
self.handle_command(command, args)
except Exception as e:
logger.exception(u'Unhandled exception: ')
logger.exception("Unhandled exception during line recived: ")
elif response is not None and isinstance(response, unicode):
logger.debug('received reply: %s (%s)' % (response, args))
self.clearLineBuffer()
self.handle_response(response, args)
......@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main():
if len(sys.argv) == 1:
if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements.
......
......@@ -7,8 +7,8 @@ from os.path import join
import logging
import tarfile
from StringIO import StringIO
from base64 import decodestring
from io import StringIO
from base64 import b64decode
from hashlib import md5
from datetime import datetime
import win32api
......@@ -20,6 +20,11 @@ from twisted.internet import reactor
from .network import change_ip_windows
from context import BaseContext
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
......@@ -28,6 +33,7 @@ class Context(BaseContext):
@staticmethod
def change_password(password):
BaseContext.placed = True
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
......@@ -39,6 +45,7 @@ class Context(BaseContext):
@staticmethod
def change_ip(interfaces, dns):
BaseContext.placed = True
nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers)
......@@ -111,7 +118,7 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
decoded = StringIO(b64decode(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory)
......
......@@ -7,12 +7,14 @@ Serial port support for Windows.
Requires PySerial and pywin32.
"""
# system imports
import win32file
import win32event
import win32con
# system imports
from serial.serialutil import to_bytes # type: ignore[import]
from time import sleep
# twisted imports
from twisted.internet import abstract
......@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor):
connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor):
def __init__(self, protocol, deviceName, reactor):
self.initHard(protocol, deviceName, reactor)
self.initSoft(protocol, deviceName, reactor)
def initHard(self, protocol, deviceName, reactor):
self.hComPort = win32file.CreateFile(
deviceNameOrPortNumber,
deviceName,
win32con.GENERIC_READ | win32con.GENERIC_WRITE,
0, # exclusive access
None, # no security
......@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor):
0)
self.reactor = reactor
self.protocol = protocol
self.outQueue = []
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.protocol = protocol
self.deviceName = deviceName
self._overlappedRead = win32file.OVERLAPPED()
self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None)
self._overlappedWrite = win32file.OVERLAPPED()
self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None)
self.reactor.addEvent(self._overlappedRead.hEvent, self, 'serialReadEvent')
self.reactor.addEvent(self._overlappedWrite.hEvent, self, 'serialWriteEvent')
self.reactor.addEvent(
self._overlappedRead.hEvent,
self,
'serialReadEvent')
self.reactor.addEvent(
self._overlappedWrite.hEvent,
self,
'serialWriteEvent')
def initSoft(self, protocol, deviceName, reactor):
self.outQueue = []
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.conneted = 1
self._reconnInProgress = False
self.protocol.makeConnection(self)
self._finishPortSetup()
self._startReading()
def _finishPortSetup(self):
"""
Finish setting up the serial port.
This is a separate method to facilitate testing.
"""
def _startReading(self, len=4096):
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
win32file.AllocateReadBuffer(len),
self._overlappedRead)
def serialReadEvent(self):
# get that character we set up
logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try:
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
0)
except Exception:
import time
time.sleep(10)
n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
except Exception as e:
logger.debug("Exception %s" % e)
sleep(10)
logger.debug(self.connLost())
n = 0
if n:
first = str(self.read_buf[:n])
# now we should get everything that is already in the buffer (max
# 4096)
if n > 0:
# handle the received data:
self.protocol.dataReceived(to_bytes(self.read_buf[:n]))
# set up next read
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(4096),
self._overlappedRead)
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
self._overlappedRead)
self._startReading()
def write(self, data):
if data:
if isinstance(data, str):
data = str.encode(data)
if self.writeInProgress:
self.outQueue.append(data)
logger.debug("added to queue")
else:
self.writeInProgress = 1
win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file")
ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file %s", ret)
def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost())
self.writeInProgress = 0
return
try:
dataToWrite = self.outQueue.pop(0)
except IndexError:
self.writeInProgress = 0
return
else:
win32file.WriteFile(
self.hComPort,
dataToWrite,
self._overlappedWrite)
win32file.WriteFile(self.hComPort, dataToWrite, self._overlappedWrite)
def connLost(self):
if self._reconnInProgress:
return None
self._reconnInProgress = True
return self.reactor.callLater(30, self.connectionLostEvent, self)
def connectionLostEvent(self, reason):
abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
logger.debug("Reconecting after 30s")
# sleep(30)
self.initSoft(self.protocol, self.deviceName, self.reactor)
def connectionLost(self, reason):
def connectionLost(self, reason=None):
"""
Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the
serial data.
"""
# import pdb; pdb.set_trace()
win32file.CancelIo(self.hComPort)
self.reactor.removeEvent(self._overlappedRead.hEvent)
self.reactor.removeEvent(self._overlappedWrite.hEvent)
win32file.CloseHandle(self._overlappedRead.hEvent)
win32file.CloseHandle(self._overlappedWrite.hEvent)
abstract.FileDescriptor.connectionLost(self, reason)
win32file.CloseHandle(self.hComPort)
self.protocol.connectionLost(reason)
logger.debug("Hard reconecting after 10s")
sleep(10)
self.initHard(self.protocol, self.deviceName, self.reactor)
self.initSoft(self.protocol, self.deviceName, self.reactor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment