Commit ecbb709b by Szeberényi Imre

Python3 version

parent 6bdaacab
......@@ -8,6 +8,7 @@ import win32event
import win32service
import win32serviceutil
#import agent
from agent import main as agent_main, reactor
logger = logging.getLogger()
......@@ -17,7 +18,7 @@ formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO')
level = os.environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level)
logger.info("%s loaded", __file__)
......@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main():
if len(sys.argv) == 1:
if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements.
......
......@@ -12,68 +12,106 @@ system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements/linux.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall
import uptime
import logging
from inspect import getargspec, isfunction
from inspect import getargs, isfunction
from utils import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from context import BaseContext, get_context, get_serial # noqa
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
try:
from inspect import getfullargspec as getargspec
except ImportError:
from inspect import getargspec as getargspec
#############################################################
Context = get_context()
logging.basicConfig()
logging.basicConfig(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
level = environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.transport.write('\r\n')
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(),
'system': system})
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.tickId = LoopingCall(self.tick)
self.mayStartNowId = LoopingCall(self.mayStartNow)
reactor.addSystemEventTrigger("before", "shutdown", self.shutdown)
self.running = True
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionMade(self):
logger.debug("connectionMade")
self.clearLineBuffer()
self.tickId.start(5, now=False)
self.mayStartNowId.start(10, now=False)
self.send_startMsg()
def connectionLost(self, reason):
reactor.stop()
logger.debug("connectionLost")
if self.tickId.running:
self.tickId.stop()
if self.mayStartNowId.running:
self.mayStartNowId.stop()
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
self.send_command(command='agent_stopped', args={})
def mayStartNow(self):
if BaseContext.placed:
self.mayStartNowId.stop()
logger.info("Placed")
return
self.send_startMsg()
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except Exception:
logger.exception("Twisted hide exception")
except Exception as e:
logger.debug("Exception durig tick: %s" % e)
# logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def shutdown(self):
self.running = False
logger.debug("shutdown")
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
def send_startMsg(self):
logger.debug("Sending start message...")
# Hack for flushing the lower level buffersr
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
def send_status(self):
import psutil
......@@ -90,6 +128,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
logger.debug("send_status finished")
def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
......@@ -97,7 +136,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
# check for unexpected keyword arguments
argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args
try:
_kwargs = argspec.keywords
except AttributeError:
_kwargs = argspec.varkw
if _kwargs is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
......@@ -111,9 +154,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished")
def _get_command(self, command, args):
if not isinstance(command, basestring) or command.startswith('_'):
logger.debug("_get_command %s %s" % (command, args))
if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
func = getattr(Context, command)
......@@ -142,8 +187,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args))
func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args))
retval = func(**args)
logger.debug("Retval: %s" % retval)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
......@@ -151,7 +199,6 @@ class SerialLineReceiver(SerialLineReceiverBase):
def handle_response(self, response, args):
pass
def main():
# Get proper serial class and port name
(serial, port) = get_serial()
......
""" This is the defautl context file. It replaces the Context class
to the platform specific one.
"""
import logging
import platform
logger = logging.getLogger()
def _get_virtio_device():
path = None
......@@ -18,6 +20,8 @@ def _get_virtio_device():
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
break
logger.debug("DEV found: %s", path)
return path
......@@ -36,6 +40,7 @@ def get_context():
def get_serial():
system = platform.system()
logger.debug("Get_serial system: %s", system)
port = None
if system == 'Windows':
port = _get_virtio_device()
......@@ -49,8 +54,10 @@ def get_serial():
elif system == "Linux":
port = "/dev/virtio-ports/agent"
try:
open(port, 'rw').close()
except (OSError, IOError):
print("Open!")
open(port, 'r').close()
except (OSError, IOError) as e:
print(e)
from twisted.internet.serialport import SerialPort
port = '/dev/ttyS0'
else:
......@@ -70,6 +77,8 @@ def get_serial():
class BaseContext(object):
placed = False # if we reciwed password or net commands
@staticmethod
def change_password(password):
pass
......
......@@ -3,7 +3,11 @@
from os import mkdir, remove, chown
from pwd import getpwnam
import platform
try:
from distro import distro as platform
except ImportError:
import platform
from shutil import rmtree, move
import subprocess
import sys
......@@ -17,12 +21,21 @@ import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from StringIO import StringIO
from base64 import decodestring
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
try:
from base64 import decodestring
except ImportError:
from base64 import decodebytes as decodestring
from hashlib import md5
from ssh import PubKey
from .ssh import PubKey
from .network import change_ip_ubuntu, change_ip_rhel
from context import BaseContext
......@@ -77,7 +90,8 @@ class Context(BaseContext):
def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
proc.communicate(str.encode('cloud:%s\n' % password))
BaseContext.placed = True
@staticmethod
def restart_networking():
......@@ -90,10 +104,12 @@ class Context(BaseContext):
@staticmethod
def change_ip(interfaces, dns):
nameservers = dns.replace(' ', '').split(',')
logger.debug("Distro: %s" % distro)
if distro == 'debian':
change_ip_ubuntu(interfaces, nameservers)
elif distro == 'rhel':
change_ip_rhel(interfaces, nameservers)
BaseContext.placed = True
@staticmethod
def set_time(time):
......@@ -123,7 +139,7 @@ class Context(BaseContext):
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
print(line.rstrip())
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
......
......@@ -27,7 +27,7 @@ def remove_interfaces_ubuntu(devices):
if line.startswith('#') or line == '' or line.isspace() or not words:
# keep line
print line
print(line)
continue
if (words[0] in ('auto', 'allow-hotplug') and
......@@ -49,7 +49,7 @@ def remove_interfaces_ubuntu(devices):
continue
# keep line
print line
print(line)
def change_ip_ubuntu(interfaces, nameservers):
......
......@@ -34,9 +34,16 @@ class SerialPort(abstract.FileDescriptor):
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.reactor = reactor
self.protocol = protocol
self.connected = 1
self.protocol.makeConnection(self)
self.startReading()
def write(self, data):
if data:
if isinstance(data, str):
data = str.encode(data)
self.writeSomeData(data)
def fileno(self):
return self._serial
......@@ -60,9 +67,12 @@ class SerialPort(abstract.FileDescriptor):
serial data.
"""
abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
os.close(self._serial)
sleep(2)
logger.debug("Reconecting after 5s")
sleep(5)
self._serial = os.open(
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.connected = 1
self.protocol.makeConnection(self)
self.startReading()
logger.info("Reconnecting")
from base64 import decodestring
try:
from base64 import decodestring
except ImportError:
from base64 import decodebytes as decodestring
from struct import unpack
import binascii
import unittest
......
......@@ -5,7 +5,6 @@
# Notify user about vm expiring
##
import cookielib
import errno
import json
import logging
......@@ -13,8 +12,27 @@ import multiprocessing
import os
import platform
import subprocess
import urllib2
from urlparse import urlsplit
try:
import cookielib
except ImportError:
import http.cookiejar as cookielib
try:
import urllib2
except ImportError:
import urllib.request as urllib2
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
logger.debug("notify imported")
......@@ -43,6 +61,8 @@ def get_temp_dir():
def wall(text):
if isinstance(text, str):
text = str.encode(text)
if win:
return
if text is None:
......@@ -58,7 +78,7 @@ def accept():
from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path):
print "There is no recent notification to accept."
print("There is no recent notification to accept.")
return False
# Load the saved url
......@@ -70,9 +90,9 @@ def accept():
opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url))
token = [c for c in cookies if c.name == "csrftoken"][0].value
req = urllib2.Request(url, "", {
"accept": "application/json", "referer": url,
"x-csrftoken": token})
req = urllib2.Request(url, b"", {
b"accept": b"application/json", b"referer": url,
b"x-csrftoken": token})
rsp = opener.open(req)
data = json.load(rsp)
newtime = data["new_suspend_time"]
......@@ -83,10 +103,10 @@ def accept():
new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e:
print "Parsing time failed: %s" % e
print("Parsing time failed: %s" % e)
except Exception as e:
print e
print "Renewal failed. Please try it manually at %s" % url
print(e)
print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed")
return False
else:
......@@ -144,6 +164,7 @@ def open_in_browser(url):
def mount_smb(url):
data = urlsplit(url)
share = data.path.lstrip('/')
print("host: %s share %s user: %s pw: %s" % (data.hostname, share, data.username, data.password))
subprocess.call(('net', 'use', 'Z:', '/delete'))
try:
p = subprocess.Popen((
......@@ -229,8 +250,19 @@ if win:
class SubProtocol(basic.LineReceiver):
def connectionMade(self):
logger.info("Subclient connected: %s", unicode(self))
clients.add(self)
def connectionLost(self, reason):
logger.info("Subclient disconnected: %s", unicode(self))
clients.remove(self)
def lineReceived(self, line):
print "received", line
logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str):
line = line.decode()
if line.startswith('cifs://'):
mount_smb(line)
else:
......@@ -243,7 +275,7 @@ if win:
def run_client():
from twisted.internet import reactor
print "connect to localhost:%d" % port
print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory())
reactor.run()
......
Twisted
pyserial
psutil
uptime
netifaces
netaddr
infi.devicemanager
tzlocal
pytz
distro
Twisted
pyserial
psutil
uptime
netifaces
netaddr
infi.devicemanager
tzlocal
pytz
pywin32
wmi
......@@ -3,6 +3,11 @@ import json
import logging
import platform
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
system = platform.system()
......@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object):
def __init__(self, *args, **kwargs):
if system == "FreeBSD":
self.delimiter = '\n'
self.delimiter = b'\n'
else:
self.delimiter = '\r'
self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args):
# logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n')
......@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data):
logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)):
data = data.strip('\0')
else:
data = data.strip(b'\0')
try:
data = json.loads(data)
args = data.get('args', {})
......@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer()
return
if command is not None and isinstance(command, unicode):
......@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object):
try:
self.handle_command(command, args)
except Exception as e:
logger.exception(u'Unhandled exception: ')
logger.exception("Unhandled exception during line recived: ")
elif response is not None and isinstance(response, unicode):
logger.debug('received reply: %s (%s)' % (response, args))
self.clearLineBuffer()
self.handle_response(response, args)
......@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main():
if len(sys.argv) == 1:
if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements.
......
......@@ -7,8 +7,8 @@ from os.path import join
import logging
import tarfile
from StringIO import StringIO
from base64 import decodestring
from io import StringIO
from base64 import b64decode
from hashlib import md5
from datetime import datetime
import win32api
......@@ -20,6 +20,11 @@ from twisted.internet import reactor
from .network import change_ip_windows
from context import BaseContext
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger()
......@@ -28,6 +33,7 @@ class Context(BaseContext):
@staticmethod
def change_password(password):
BaseContext.placed = True
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
......@@ -39,6 +45,7 @@ class Context(BaseContext):
@staticmethod
def change_ip(interfaces, dns):
BaseContext.placed = True
nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers)
......@@ -111,7 +118,7 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
decoded = StringIO(b64decode(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory)
......
......@@ -7,12 +7,14 @@ Serial port support for Windows.
Requires PySerial and pywin32.
"""
# system imports
import win32file
import win32event
import win32con
# system imports
from serial.serialutil import to_bytes # type: ignore[import]
from time import sleep
# twisted imports
from twisted.internet import abstract
......@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor):
connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor):
def __init__(self, protocol, deviceName, reactor):
self.initHard(protocol, deviceName, reactor)
self.initSoft(protocol, deviceName, reactor)
def initHard(self, protocol, deviceName, reactor):
self.hComPort = win32file.CreateFile(
deviceNameOrPortNumber,
deviceName,
win32con.GENERIC_READ | win32con.GENERIC_WRITE,
0, # exclusive access
None, # no security
......@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor):
0)
self.reactor = reactor
self.protocol = protocol
self.outQueue = []
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.protocol = protocol
self.deviceName = deviceName
self._overlappedRead = win32file.OVERLAPPED()
self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None)
self._overlappedWrite = win32file.OVERLAPPED()
self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None)
self.reactor.addEvent(self._overlappedRead.hEvent, self, 'serialReadEvent')
self.reactor.addEvent(self._overlappedWrite.hEvent, self, 'serialWriteEvent')
self.reactor.addEvent(
self._overlappedRead.hEvent,
self,
'serialReadEvent')
self.reactor.addEvent(
self._overlappedWrite.hEvent,
self,
'serialWriteEvent')
def initSoft(self, protocol, deviceName, reactor):
self.outQueue = []
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.conneted = 1
self._reconnInProgress = False
self.protocol.makeConnection(self)
self._finishPortSetup()
self._startReading()
def _finishPortSetup(self):
"""
Finish setting up the serial port.
This is a separate method to facilitate testing.
"""
def _startReading(self, len=4096):
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
win32file.AllocateReadBuffer(len),
self._overlappedRead)
def serialReadEvent(self):
# get that character we set up
logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try:
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
0)
except Exception:
import time
time.sleep(10)
n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
except Exception as e:
logger.debug("Exception %s" % e)
sleep(10)
logger.debug(self.connLost())
n = 0
if n:
first = str(self.read_buf[:n])
# now we should get everything that is already in the buffer (max
# 4096)
if n > 0:
# handle the received data:
self.protocol.dataReceived(to_bytes(self.read_buf[:n]))
# set up next read
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(4096),
self._overlappedRead)
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
self._overlappedRead)
self._startReading()
def write(self, data):
if data:
if isinstance(data, str):
data = str.encode(data)
if self.writeInProgress:
self.outQueue.append(data)
logger.debug("added to queue")
else:
self.writeInProgress = 1
win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file")
ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file %s", ret)
def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost())
self.writeInProgress = 0
return
try:
dataToWrite = self.outQueue.pop(0)
except IndexError:
self.writeInProgress = 0
return
else:
win32file.WriteFile(
self.hComPort,
dataToWrite,
self._overlappedWrite)
win32file.WriteFile(self.hComPort, dataToWrite, self._overlappedWrite)
def connLost(self):
if self._reconnInProgress:
return None
self._reconnInProgress = True
return self.reactor.callLater(30, self.connectionLostEvent, self)
def connectionLostEvent(self, reason):
abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
logger.debug("Reconecting after 30s")
# sleep(30)
self.initSoft(self.protocol, self.deviceName, self.reactor)
def connectionLost(self, reason):
def connectionLost(self, reason=None):
"""
Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the
serial data.
"""
# import pdb; pdb.set_trace()
win32file.CancelIo(self.hComPort)
self.reactor.removeEvent(self._overlappedRead.hEvent)
self.reactor.removeEvent(self._overlappedWrite.hEvent)
win32file.CloseHandle(self._overlappedRead.hEvent)
win32file.CloseHandle(self._overlappedWrite.hEvent)
abstract.FileDescriptor.connectionLost(self, reason)
win32file.CloseHandle(self.hComPort)
self.protocol.connectionLost(reason)
logger.debug("Hard reconecting after 10s")
sleep(10)
self.initHard(self.protocol, self.deviceName, self.reactor)
self.initSoft(self.protocol, self.deviceName, self.reactor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment