Commit 38b14478 by Bach Dániel

Merge branch 'virtio' into 'master'

Virtio

See merge request !4
parents e0d5ca1a 8879b198
""" This is the defautl context file. It replaces the Context class
to the platform specific one.
"""
import platform
from os.path import exists
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
def get_context():
system = platform.system()
if system == "Windows":
from windows._win32context import Context
elif system == "Linux":
from linux._linuxcontext import Context
else:
raise NotImplementedError("Platform %s is not supported.", system)
return Context
def get_serial():
system = platform.system()
port = None
if system == 'Windows':
port = _get_virtio_device()
if port:
from windows.win32virtio import SerialPort
else:
from twisted.internet.serial import SerialPort
import pythoncom
pythoncom.CoInitialize()
port = r'\\.\COM1'
elif system == "Linux":
port = "/dev/virtio-ports/agent"
if exists(port):
from linux.posixvirtio import SerialPort
else:
from twisted.internet.serial import SerialPort
port = '/dev/ttyS0'
else:
raise NotImplementedError("Platform %s is not supported.", system)
return (SerialPort, port)
class BaseContext(object):
@staticmethod
def change_password(password):
pass
@staticmethod
def restart_networking():
pass
@staticmethod
def change_ip(interfaces, dns):
pass
@staticmethod
def set_time(time):
pass
@staticmethod
def set_hostname(hostname):
pass
@staticmethod
def mount_store(host, username, password):
pass
@staticmethod
def get_keys():
pass
@staticmethod
def add_keys(keys):
pass
@staticmethod
def del_keys(keys):
pass
@staticmethod
def cleanup():
pass
@staticmethod
def start_access_server():
pass
@staticmethod
def append(data, filename, chunk_number, uuid):
pass
@staticmethod
def update(filename, executable, checksum, uuid):
pass
@staticmethod
def ipaddresses():
pass
@staticmethod
def get_agent_version():
try:
with open('version.txt') as f:
return f.readline()
except IOError:
return None
@staticmethod
def send_expiration(url):
import notify
notify.notify(url)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import mkdir, environ, chdir
import platform
from shutil import copy, rmtree, move
import subprocess
import sys
system = platform.system()
working_directory = sys.path[0]
try:
chdir(working_directory)
subprocess.call(('pip', 'install', '-r', 'requirements.txt'))
if system == 'Linux':
copy("/root/agent/misc/vm_renewal", "/usr/local/bin/")
except:
pass # hope it works
import logging
import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from StringIO import StringIO
from base64 import decodestring
from hashlib import md5
from ssh import PubKey
from .network import change_ip_ubuntu, change_ip_rhel
from context import BaseContext
from twisted.internet import reactor
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
SSH_DIR = expanduser('~cloud/.ssh')
AUTHORIZED_KEYS = join(SSH_DIR, 'authorized_keys')
STORE_DIR = '/store'
mount_template_linux = (
'//%(host)s/%(username)s %(dir)s cifs username=%(username)s'
',password=%(password)s,iocharset=utf8,uid=cloud 0 0\n')
distros = {'Scientific Linux': 'rhel',
'CentOS': 'rhel',
'CentOS Linux': 'rhel',
'Debian': 'debian',
'Ubuntu': 'debian'}
if system == 'Linux':
distro = distros[platform.linux_distribution()[0]]
class Context(BaseContext):
# http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time
def _linux_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
class timespec(ctypes.Structure):
_fields_ = [("tv_sec", ctypes.c_long),
("tv_nsec", ctypes.c_long)]
librt = ctypes.CDLL(ctypes.util.find_library("rt"))
ts = timespec()
ts.tv_sec = int(time)
ts.tv_nsec = 0
librt.clock_settime(CLOCK_REALTIME, ctypes.byref(ts))
@staticmethod
def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
@staticmethod
def restart_networking():
if distro == 'debian':
subprocess.call(['/etc/init.d/networking', 'restart'])
elif distro == 'rhel':
subprocess.call(['/bin/systemctl', 'restart', 'network'])
pass
@staticmethod
def change_ip(interfaces, dns):
if distro == 'debian':
change_ip_ubuntu(interfaces, dns)
elif distro == 'rhel':
change_ip_rhel(interfaces, dns)
@staticmethod
def set_time(time):
Context._linux_set_time(float(time))
try:
subprocess.call(['/etc/init.d/ntp', 'restart'])
except:
pass
@staticmethod
def set_hostname(hostname):
if distro == 'debian':
with open('/etc/hostname', 'w') as f:
f.write(hostname)
elif distro == 'rhel':
for line in fileinput.input('/etc/sysconfig/network',
inplace=1):
if line.startswith('HOSTNAME='):
print 'HOSTNAME=%s' % hostname
else:
print line.rstrip()
with open('/etc/hosts', 'w') as f:
f.write("127.0.0.1 localhost\n"
"127.0.1.1 %s\n" % hostname)
subprocess.call(['/bin/hostname', hostname])
@staticmethod
def mount_store(host, username, password):
data = {'host': host, 'username': username, 'password': password}
data['dir'] = STORE_DIR
if not exists(STORE_DIR):
mkdir(STORE_DIR)
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
subprocess.call('mount -a', shell=True)
@staticmethod
def get_keys():
retval = []
try:
with open(AUTHORIZED_KEYS, 'r') as f:
for line in f.readlines():
try:
retval.append(PubKey.from_str(line))
except:
logger.exception(u'Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
@staticmethod
def add_keys(keys):
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
if p not in new_keys:
new_keys.append(p)
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def del_keys(keys):
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
try:
new_keys.remove(p)
except ValueError:
pass
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def cleanup():
filelist = ([
'/root/.bash_history'
'/home/cloud/.bash_history'
'/root/.ssh'
'/home/cloud/.ssh']
+ glob('/etc/ssh/ssh_host_*'))
for f in filelist:
rmtree(f, ignore_errors=True)
subprocess.call(('/usr/bin/ssh-keygen', '-A'))
@staticmethod
def start_access_server():
try:
subprocess.call(('/sbin/start', 'ssh'))
except OSError:
subprocess.call(('/bin/systemctl', 'start', 'sshd.service'))
@staticmethod
def append(data, filename, chunk_number, uuid):
if chunk_number == 0:
flag = "w"
else:
flag = "a"
with open(filename, flag) as myfile:
myfile.write(data)
@staticmethod
def update(filename, executable, checksum, uuid):
new_dir = working_directory + '.new'
old_dir = working_directory + '.old'
with open(filename, "r") as f:
data = f.read()
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
move(working_directory, old_dir)
move(new_dir, working_directory)
logger.info("Transfer completed!")
reactor.stop()
@staticmethod
def ipaddresses():
import netifaces
args = {}
interfaces = netifaces.interfaces()
for i in interfaces:
if i == 'lo':
continue
args[i] = []
addresses = netifaces.ifaddresses(i)
args[i] = ([x['addr']
for x in addresses.get(netifaces.AF_INET, [])] +
[x['addr']
for x in addresses.get(netifaces.AF_INET6, [])
if '%' not in x['addr']])
return args
@staticmethod
def get_agent_version():
try:
with open('version.txt') as f:
return f.readline()
except IOError:
return None
@staticmethod
def send_expiration(url):
import notify
notify.notify(url)
import netifaces import netifaces
from netaddr import IPNetwork, IPAddress from netaddr import IPNetwork
import fileinput import fileinput
import logging import logging
from subprocess import check_output, CalledProcessError
logger = logging.getLogger() logger = logging.getLogger()
...@@ -122,68 +121,3 @@ def change_ip_rhel(interfaces, dns): ...@@ -122,68 +121,3 @@ def change_ip_rhel(interfaces, dns):
'ip': ip, 'ip': ip,
'netmask': str(ip_with_prefix.netmask), 'netmask': str(ip_with_prefix.netmask),
'gw': conf['gw4']}) 'gw': conf['gw4']})
def get_interfaces_windows(interfaces):
import wmi
nics = wmi.WMI().Win32_NetworkAdapterConfiguration(IPEnabled=True)
for nic in nics:
conf = interfaces.get(nic.MACAddress)
if conf:
yield nic, conf
def change_ip_windows(interfaces, dns):
for nic, conf in get_interfaces_windows(interfaces):
link_local = IPNetwork('fe80::/16')
new_addrs = [IPNetwork(ip) for ip in conf['addresses']]
new_addrs_str = set(str(ip) for ip in new_addrs)
old_addrs = [IPNetwork('%s/%s' % (ip, nic.IPSubnet[i]))
for i, ip in enumerate(nic.IPAddress)
if IPAddress(ip) not in link_local]
old_addrs_str = set(str(ip) for ip in old_addrs)
changed = (
new_addrs_str != old_addrs_str or
set(nic.DefaultIPGateway) != set([conf['gw4'], conf['gw6']]))
if changed or 1:
logger.info('new config for <%s(%s)>: %s', nic.Description,
nic.MACAddress, ', '.join(new_addrs_str))
# IPv4
ipv4_addrs = [str(ip.ip) for ip in new_addrs
if ip.version == 4]
ipv4_masks = [str(ip.netmask) for ip in new_addrs
if ip.version == 4]
logger.debug('<%s>.EnableStatic(%s, %s) called', nic.Description,
ipv4_addrs, ipv4_masks)
retval = nic.EnableStatic(
IPAddress=ipv4_addrs, SubnetMask=ipv4_masks)
assert retval == (0, )
nic.SetGateways(DefaultIPGateway=[conf['gw4']])
assert retval == (0, )
# IPv6
for ip in new_addrs:
if ip.version == 6 and str(ip) not in old_addrs_str:
logger.debug('add %s (%s)', ip, nic.Description)
check_output(
'netsh interface ipv6 add address '
'interface=%s address=%s'
% (nic.InterfaceIndex, ip), shell=True)
for ip in old_addrs:
if ip.version == 6 and str(ip) not in new_addrs_str:
logger.debug('delete %s (%s)', ip, nic.Description)
check_output(
'netsh interface ipv6 delete address '
'interface=%s address=%s'
% (nic.InterfaceIndex, ip.ip), shell=True)
try:
check_output('netsh interface ipv6 del route ::/0 interface=%s'
% nic.InterfaceIndex, shell=True)
except CalledProcessError:
pass
check_output('netsh interface ipv6 add route ::/0 interface=%s %s'
% (nic.InterfaceIndex, conf['gw6']), shell=True)
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Virtio-Serial Port Protocol
"""
# system imports
import os
# dependent on pyserial ( http://pyserial.sf.net/ )
# only tested w/ 1.18 (5 Dec 2002)
# twisted imports
from twisted.internet import abstract, fdesc
class SerialPort(abstract.FileDescriptor):
"""
A select()able serial device, acting as a transport.
"""
connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor):
abstract.FileDescriptor.__init__(self, reactor)
self.port = deviceNameOrPortNumber
self._serial = os.open(
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.reactor = reactor
self.protocol = protocol
self.protocol.makeConnection(self)
self.startReading()
def fileno(self):
return self._serial
def writeSomeData(self, data):
"""
Write some data to the serial device.
"""
return fdesc.writeToFD(self.fileno(), data)
def doRead(self):
"""
Some data's readable from serial device.
"""
return fdesc.readFromFD(self.fileno(), self.protocol.dataReceived)
def connectionLost(self, reason):
"""
Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the
serial data.
"""
abstract.FileDescriptor.connectionLost(self, reason)
os.close(self._serial)
self.protocol.connectionLost(reason)
...@@ -4,3 +4,4 @@ psutil==1.2.1 ...@@ -4,3 +4,4 @@ psutil==1.2.1
uptime==3.0.1 uptime==3.0.1
netifaces==0.10.4 netifaces==0.10.4
netaddr==0.7.12 netaddr==0.7.12
infi.devicemanager
\ No newline at end of file
...@@ -7,6 +7,7 @@ logger = logging.getLogger() ...@@ -7,6 +7,7 @@ logger = logging.getLogger()
class SerialLineReceiverBase(LineReceiver, object): class SerialLineReceiverBase(LineReceiver, object):
delimiter = '\r' delimiter = '\r'
MAX_LENGTH = 1024*1024*128
def send_response(self, response, args): def send_response(self, response, args):
self.transport.write(json.dumps({'response': response, self.transport.write(json.dumps({'response': response,
......
import logging
from logging.handlers import NTEventLogHandler
from time import sleep
import os
import servicemanager
import socket
import sys
import win32event
import win32service
import win32serviceutil
logger = logging.getLogger()
fh = NTEventLogHandler(
"CIRCLE Watchdog", dllname=os.path.dirname(__file__))
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
logger.info("%s loaded", __file__)
service_name = "circle-agent"
stopped = False
def watch():
def check_service(service_name):
return win32serviceutil.QueryServiceStatus(service_name)[1] == 4
def start_service():
win32serviceutil.StartService(service_name)
while True:
if not check_service(service_name):
logger.info("Service %s is not running.", service_name)
start_service()
if stopped:
return
sleep(10)
class AppServerSvc (win32serviceutil.ServiceFramework):
_svc_name_ = "circle-watchdog"
_svc_display_name_ = "CIRCLE Watchdog"
def __init__(self, args):
win32serviceutil.ServiceFramework.__init__(self, args)
self.hWaitStop = win32event.CreateEvent(None, 0, 0, None)
socket.setdefaulttimeout(60)
def SvcStop(self):
self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING)
win32event.SetEvent(self.hWaitStop)
global stopped
stopped = True
logger.info("%s stopped", __file__)
def SvcDoRun(self):
servicemanager.LogMsg(servicemanager.EVENTLOG_INFORMATION_TYPE,
servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, ''))
logger.info("%s starting", __file__)
watch()
def main():
if len(sys.argv) == 1:
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements.
import win32traceutil # noqa
logger.info("service is starting...")
servicemanager.Initialize()
servicemanager.PrepareToHostSingle(AppServerSvc)
# Now ask the service manager to fire things up for us...
servicemanager.StartServiceCtrlDispatcher()
logger.info("service done!")
else:
win32serviceutil.HandleCommandLine(AppServerSvc)
if __name__ == '__main__':
try:
main()
except (SystemExit, KeyboardInterrupt):
raise
except:
logger.exception("Exception:")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import platform
system = platform.system()
working_directory = r"C:\circle"
from os import environ
from os.path import join
import logging
import tarfile
from StringIO import StringIO
from base64 import decodestring
from hashlib import md5
from datetime import datetime
import win32api
import wmi
import netifaces
from twisted.internet import reactor
from .network import change_ip_windows
from context import BaseContext
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
class Context(BaseContext):
@staticmethod
def change_password(password):
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
ads_obj.SetPassword(password)
@staticmethod
def restart_networking():
pass
@staticmethod
def change_ip(interfaces, dns):
change_ip_windows(interfaces, dns)
@staticmethod
def set_time(time):
t = datetime.utcfromtimestamp(float(time))
win32api.SetSystemTime(t.year, t.month, 0, t.day, t.hour,
t.minute, t.second, 0)
@staticmethod
def set_hostname(hostname):
wmi.WMI().Win32_ComputerSystem()[0].Rename(hostname)
@staticmethod
def mount_store(host, username, password):
import notify
url = 'cifs://%s:%s@%s/%s' % (username, password, host, username)
for c in notify.clients:
logger.debug("sending url %s to client %s", url, unicode(c))
c.sendLine(url.encode())
@staticmethod
def get_keys():
pass
@staticmethod
def add_keys(keys):
pass
@staticmethod
def del_keys(keys):
pass
@staticmethod
def cleanup():
# TODO
pass
@staticmethod
def start_access_server():
# TODO
pass
@staticmethod
def append(data, filename, chunk_number, uuid):
if chunk_number == 0:
flag = "w"
else:
flag = "a"
with open(filename, flag) as myfile:
myfile.write(data)
@staticmethod
def _update_registry(dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from _winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, join(dir, executable))
return old_executable
@staticmethod
def update(filename, executable, checksum, uuid):
with open(filename, "r") as f:
data = f.read()
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory)
except tarfile.ReadError as e:
logger.error(e)
logger.info("Transfer completed!")
Context._update_registry(working_directory, executable)
logger.info('Updated')
reactor.stop()
@staticmethod
def ipaddresses():
args = {}
interfaces = netifaces.interfaces()
for i in interfaces:
if i == 'lo':
continue
args[i] = []
addresses = netifaces.ifaddresses(i)
args[i] = ([x['addr']
for x in addresses.get(netifaces.AF_INET, [])] +
[x['addr']
for x in addresses.get(netifaces.AF_INET6, [])
if '%' not in x['addr']])
return args
@staticmethod
def get_agent_version():
try:
with open(join(working_directory, 'version.txt')) as f:
return f.readline()
except IOError:
return None
from netaddr import IPNetwork, IPAddress
import logging
from subprocess import check_output, CalledProcessError
logger = logging.getLogger()
interfaces_file = '/etc/network/interfaces'
ifcfg_template = '/etc/sysconfig/network-scripts/ifcfg-%s'
# example:
# change_ip_ubuntu({
# u'02:00:00:02:A3:E8': {
# u'gw4': u'10.1.0.254', 'gw6': '2001::ffff',
# u'addresses': [u'10.1.0.84/24', '10.1.0.1/24', '2001::1/48']},
# u'02:00:00:02:A3:E9': {
# u'gw4': u'10.255.255.1', u'addresses': [u'10.255.255.9']}},
# '8.8.8.8')
def get_interfaces_windows(interfaces):
import wmi
nics = wmi.WMI().Win32_NetworkAdapterConfiguration(IPEnabled=True)
for nic in nics:
conf = interfaces.get(nic.MACAddress)
if conf:
yield nic, conf
def change_ip_windows(interfaces, dns):
for nic, conf in get_interfaces_windows(interfaces):
link_local = IPNetwork('fe80::/16')
new_addrs = [IPNetwork(ip) for ip in conf['addresses']]
new_addrs_str = set(str(ip) for ip in new_addrs)
old_addrs = [IPNetwork('%s/%s' % (ip, nic.IPSubnet[i]))
for i, ip in enumerate(nic.IPAddress)
if IPAddress(ip) not in link_local]
old_addrs_str = set(str(ip) for ip in old_addrs)
changed = (
new_addrs_str != old_addrs_str or
set(nic.DefaultIPGateway) != set([conf['gw4'], conf['gw6']]))
if changed or 1:
logger.info('new config for <%s(%s)>: %s', nic.Description,
nic.MACAddress, ', '.join(new_addrs_str))
# IPv4
ipv4_addrs = [str(ip.ip) for ip in new_addrs
if ip.version == 4]
ipv4_masks = [str(ip.netmask) for ip in new_addrs
if ip.version == 4]
logger.debug('<%s>.EnableStatic(%s, %s) called', nic.Description,
ipv4_addrs, ipv4_masks)
retval = nic.EnableStatic(
IPAddress=ipv4_addrs, SubnetMask=ipv4_masks)
assert retval == (0, )
nic.SetGateways(DefaultIPGateway=[conf['gw4']])
assert retval == (0, )
# IPv6
for ip in new_addrs:
if ip.version == 6 and str(ip) not in old_addrs_str:
logger.debug('add %s (%s)', ip, nic.Description)
check_output(
'netsh interface ipv6 add address '
'interface=%s address=%s'
% (nic.InterfaceIndex, ip), shell=True)
for ip in old_addrs:
if ip.version == 6 and str(ip) not in new_addrs_str:
logger.debug('delete %s (%s)', ip, nic.Description)
check_output(
'netsh interface ipv6 delete address '
'interface=%s address=%s'
% (nic.InterfaceIndex, ip.ip), shell=True)
try:
check_output('netsh interface ipv6 del route ::/0 interface=%s'
% nic.InterfaceIndex, shell=True)
except CalledProcessError:
pass
check_output('netsh interface ipv6 add route ::/0 interface=%s %s'
% (nic.InterfaceIndex, conf['gw6']), shell=True)
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Serial port support for Windows.
Requires PySerial and pywin32.
"""
# system imports
import win32file
import win32event
import win32con
# twisted imports
from twisted.internet import abstract
# sibling imports
import logging
logger = logging.getLogger()
class SerialPort(abstract.FileDescriptor):
"""A serial device, acting as a transport, that uses a win32 event."""
connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor):
self.hComPort = win32file.CreateFile(
deviceNameOrPortNumber,
win32con.GENERIC_READ | win32con.GENERIC_WRITE,
0, # exclusive access
None, # no security
win32con.OPEN_EXISTING,
win32con.FILE_ATTRIBUTE_NORMAL | win32con.FILE_FLAG_OVERLAPPED,
0)
self.reactor = reactor
self.protocol = protocol
self.outQueue = []
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.protocol = protocol
self._overlappedRead = win32file.OVERLAPPED()
self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None)
self._overlappedWrite = win32file.OVERLAPPED()
self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None)
self.reactor.addEvent(
self._overlappedRead.hEvent,
self,
'serialReadEvent')
self.reactor.addEvent(
self._overlappedWrite.hEvent,
self,
'serialWriteEvent')
self.protocol.makeConnection(self)
self._finishPortSetup()
def _finishPortSetup(self):
"""
Finish setting up the serial port.
This is a separate method to facilitate testing.
"""
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
self._overlappedRead)
def serialReadEvent(self):
# get that character we set up
try:
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
0)
except:
import time
time.sleep(10)
n = 0
if n:
first = str(self.read_buf[:n])
# now we should get everything that is already in the buffer (max
# 4096)
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(4096),
self._overlappedRead)
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
self._overlappedRead)
def write(self, data):
if data:
if self.writeInProgress:
self.outQueue.append(data)
logger.debug("added to queue")
else:
self.writeInProgress = 1
win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file")
def serialWriteEvent(self):
try:
dataToWrite = self.outQueue.pop(0)
except IndexError:
self.writeInProgress = 0
return
else:
win32file.WriteFile(
self.hComPort,
dataToWrite,
self._overlappedWrite)
def connectionLost(self, reason):
"""
Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the
serial data.
"""
self.reactor.removeEvent(self._overlappedRead.hEvent)
self.reactor.removeEvent(self._overlappedWrite.hEvent)
abstract.FileDescriptor.connectionLost(self, reason)
win32file.CloseHandle(self.hComPort)
self.protocol.connectionLost(reason)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment