Commit 56441d4b by Guba Sándor

refactor agent.py

parent e4baa24d
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from os import mkdir, environ, chdir from os import environ, chdir
import platform import platform
from shutil import copy from shutil import copy
import subprocess import subprocess
...@@ -9,12 +9,13 @@ import sys ...@@ -9,12 +9,13 @@ import sys
system = platform.system() system = platform.system()
try: if system == "Linux":
try:
chdir(sys.path[0]) chdir(sys.path[0])
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) subprocess.call(('pip', 'install', '-r', 'requirements.txt'))
if system == 'Linux': if system == 'Linux':
copy("/root/agent/misc/vm_renewal", "/usr/local/bin/") copy("/root/agent/misc/vm_renewal", "/usr/local/bin/")
except: except:
pass # hope it works pass # hope it works
...@@ -24,337 +25,26 @@ from twisted.internet.task import LoopingCall ...@@ -24,337 +25,26 @@ from twisted.internet.task import LoopingCall
import uptime import uptime
import logging import logging
import fileinput from os.path import exists
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from inspect import getargspec, isfunction from inspect import getargspec, isfunction
from StringIO import StringIO
from base64 import decodestring
from shutil import rmtree, move
from datetime import datetime
from utils import SerialLineReceiverBase from utils import SerialLineReceiverBase
from ssh import PubKey from context import Context
from network import change_ip_ubuntu, change_ip_rhel, change_ip_windows
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger() logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO') level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level) logger.setLevel(level)
SSH_DIR = expanduser('~cloud/.ssh')
AUTHORIZED_KEYS = join(SSH_DIR, 'authorized_keys')
STORE_DIR = '/store'
mount_template_linux = (
'//%(host)s/%(username)s %(dir)s cifs username=%(username)s'
',password=%(password)s,iocharset=utf8,uid=cloud 0 0\n')
distros = {'Scientific Linux': 'rhel',
'CentOS': 'rhel',
'CentOS Linux': 'rhel',
'Debian': 'debian',
'Ubuntu': 'debian'}
if system == 'Linux':
distro = distros[platform.linux_distribution()[0]]
# http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time
def linux_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
class timespec(ctypes.Structure):
_fields_ = [("tv_sec", ctypes.c_long),
("tv_nsec", ctypes.c_long)]
librt = ctypes.CDLL(ctypes.util.find_library("rt"))
ts = timespec()
ts.tv_sec = int(time)
ts.tv_nsec = 0
librt.clock_settime(CLOCK_REALTIME, ctypes.byref(ts))
class Context(object):
@staticmethod
def change_password(password):
if system == 'Linux':
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
elif system == 'Windows':
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
ads_obj.SetPassword(password)
@staticmethod
def restart_networking():
if system == 'Linux':
if distro == 'debian':
subprocess.call(['/etc/init.d/networking', 'restart'])
elif distro == 'rhel':
subprocess.call(['/bin/systemctl', 'restart', 'network'])
pass
elif system == 'Windows':
pass
@staticmethod
def change_ip(interfaces, dns):
if system == 'Linux':
if distro == 'debian':
change_ip_ubuntu(interfaces, dns)
elif distro == 'rhel':
change_ip_rhel(interfaces, dns)
elif system == 'Windows':
change_ip_windows(interfaces, dns)
@staticmethod
def set_time(time):
if system == 'Linux':
linux_set_time(float(time))
try:
subprocess.call(['/etc/init.d/ntp', 'restart'])
except:
pass
elif system == 'Windows':
import win32api
t = datetime.utcfromtimestamp(float(time))
win32api.SetSystemTime(t.year, t.month, 0, t.day, t.hour,
t.minute, t.second, 0)
@staticmethod
def set_hostname(hostname):
if system == 'Linux':
if distro == 'debian':
with open('/etc/hostname', 'w') as f:
f.write(hostname)
elif distro == 'rhel':
for line in fileinput.input('/etc/sysconfig/network',
inplace=1):
if line.startswith('HOSTNAME='):
print 'HOSTNAME=%s' % hostname
else:
print line.rstrip()
with open('/etc/hosts', 'w') as f:
f.write("127.0.0.1 localhost\n"
"127.0.1.1 %s\n" % hostname)
subprocess.call(['/bin/hostname', hostname])
elif system == 'Windows':
import wmi
wmi.WMI().Win32_ComputerSystem()[0].Rename(hostname)
@staticmethod
def mount_store(host, username, password):
data = {'host': host, 'username': username, 'password': password}
if system == 'Linux':
data['dir'] = STORE_DIR
if not exists(STORE_DIR):
mkdir(STORE_DIR)
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
subprocess.call('mount -a', shell=True)
elif system == 'Windows':
import notify
url = 'cifs://%s:%s@%s/%s' % (username, password, host, username)
for c in notify.clients:
logger.debug("sending url %s to client %s", url, unicode(c))
c.sendLine(url.encode())
@staticmethod
def get_keys():
retval = []
try:
with open(AUTHORIZED_KEYS, 'r') as f:
for line in f.readlines():
try:
retval.append(PubKey.from_str(line))
except:
logger.exception(u'Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
@staticmethod
def add_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
if p not in new_keys:
new_keys.append(p)
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def del_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
try:
new_keys.remove(p)
except ValueError:
pass
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def cleanup():
if system == 'Linux':
filelist = ([
'/root/.bash_history'
'/home/cloud/.bash_history'
'/root/.ssh'
'/home/cloud/.ssh']
+ glob('/etc/ssh/ssh_host_*'))
for f in filelist:
rmtree(f, ignore_errors=True)
subprocess.call(('/usr/bin/ssh-keygen', '-A'))
elif system == 'Windows':
# TODO
pass
@staticmethod
def start_access_server():
if system == 'Linux':
try:
subprocess.call(('/sbin/start', 'ssh'))
except OSError:
subprocess.call(('/bin/systemctl', 'start', 'sshd.service'))
elif system == 'Windows':
# TODO
pass
@classmethod
def _update_linux(cls, data, uuid):
cur_dir = sys.path[0]
new_dir = cur_dir + '.new'
old_dir = cur_dir + '.old'
f = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=f, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
else:
rmtree(old_dir, ignore_errors=True)
move(cur_dir, old_dir)
move(new_dir, cur_dir)
logger.info('Updated')
reactor.stop()
@classmethod
def _update_windows(cls, data, executable, uuid):
# Extract the tar to the new path
cur_dir = sys.path[0]
new_dir = cur_dir + '.version'
f = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=f, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
else:
cls._update_registry(new_dir, executable)
logger.info('Updated')
reactor.stop()
@classmethod
def _update_registry(cls, dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from _winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, join(dir, executable))
return old_executable
@staticmethod
def update(data, executable, uuid):
if system == "Windows":
Context._update_windows(data, executable, uuid)
else:
Context._update_linux(data, executable, uuid)
@staticmethod
def ipaddresses():
import netifaces
args = {}
interfaces = netifaces.interfaces()
for i in interfaces:
if i == 'lo':
continue
args[i] = []
addresses = netifaces.ifaddresses(i)
args[i] = ([x['addr']
for x in addresses.get(netifaces.AF_INET, [])] +
[x['addr']
for x in addresses.get(netifaces.AF_INET6, [])
if '%' not in x['addr']])
return args
@staticmethod
def get_agent_version():
try:
with open('version.txt') as f:
return f.readline()
except IOError:
return None
@staticmethod
def send_expiration(url):
import notify
notify.notify(url)
class SerialLineReceiver(SerialLineReceiverBase): class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self): def connectionMade(self):
self.send_command( self.send_command(
command='agent_started', command='agent_started',
args={'version': Context.get_agent_version()}) args={'version': Context.get_agent_version(),
'system': system})
def shutdown(): def shutdown():
self.connectionLost2('shutdown') self.connectionLost2('shutdown')
...@@ -473,20 +163,22 @@ def _get_virtio_device(): ...@@ -473,20 +163,22 @@ def _get_virtio_device():
def main(): def main():
port = None
if system == 'Windows': if system == 'Windows':
port = _get_virtio_device() port = _get_virtio_device()
if port: if port:
from w32serial import SerialPort from context import SerialPort
else: else:
from twisted.internet.serial import SerialPort from twisted.internet.serial import SerialPort
import pythoncom import pythoncom
pythoncom.CoInitialize() pythoncom.CoInitialize()
port = r'\\.\COM1' port = r'\\.\COM1'
else: else:
from twisted.internet.serial import SerialPort
# Try virtio first
port = "/dev/virtio-ports/agent" port = "/dev/virtio-ports/agent"
if not exists(port): if exists(port):
from context import SerialPort
else:
from twisted.internet.serial import SerialPort
port = '/dev/ttyS0' port = '/dev/ttyS0'
logger.info("Opening port %s", port) logger.info("Opening port %s", port)
SerialPort(SerialLineReceiver(), port, reactor) SerialPort(SerialLineReceiver(), port, reactor)
...@@ -494,10 +186,10 @@ def main(): ...@@ -494,10 +186,10 @@ def main():
from notify import register_publisher from notify import register_publisher
register_publisher(reactor) register_publisher(reactor)
except: except:
logger.exception("Couldnt register notify publisher") logger.exception("Could not register notify publisher")
logger.debug("Starting reactor.") logger.debug("Starting reactor.")
reactor.run() reactor.run()
logger.debug("Reactor after run.") logger.debug("Reactor finished.")
if __name__ == '__main__': if __name__ == '__main__':
......
import platform
""" This is the defautl context file. It replaces the Context class
to the platform specific one.
"""
system = platform.system()
if system == "Windows":
from windows._win32context import Context
from win32.win32virtio import SerialPort
elif system == "Linux":
from linux._linuxcontext import Context
from linux.posixvirtio import SerialPort
else:
raise NotImplementedError("Platform %s is not supported.", system)
class BaseContext():
pass
Context
SerialPort
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import mkdir, environ, chdir
import platform
from shutil import copy, rmtree, move
import subprocess
import sys
system = platform.system()
working_directory = sys.path[0]
try:
chdir(working_directory)
subprocess.call(('pip', 'install', '-r', 'requirements.txt'))
if system == 'Linux':
copy("/root/agent/misc/vm_renewal", "/usr/local/bin/")
except:
pass # hope it works
import logging
import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from StringIO import StringIO
from base64 import decodestring
from hashlib import md5
from ssh import PubKey
from network import change_ip_ubuntu, change_ip_rhel
from twisted.internet import reactor
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
SSH_DIR = expanduser('~cloud/.ssh')
AUTHORIZED_KEYS = join(SSH_DIR, 'authorized_keys')
STORE_DIR = '/store'
mount_template_linux = (
'//%(host)s/%(username)s %(dir)s cifs username=%(username)s'
',password=%(password)s,iocharset=utf8,uid=cloud 0 0\n')
distros = {'Scientific Linux': 'rhel',
'CentOS': 'rhel',
'CentOS Linux': 'rhel',
'Debian': 'debian',
'Ubuntu': 'debian'}
if system == 'Linux':
distro = distros[platform.linux_distribution()[0]]
class Context(object):
# http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time
def _linux_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
class timespec(ctypes.Structure):
_fields_ = [("tv_sec", ctypes.c_long),
("tv_nsec", ctypes.c_long)]
librt = ctypes.CDLL(ctypes.util.find_library("rt"))
ts = timespec()
ts.tv_sec = int(time)
ts.tv_nsec = 0
librt.clock_settime(CLOCK_REALTIME, ctypes.byref(ts))
@staticmethod
def change_password(password):
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
@staticmethod
def restart_networking():
if distro == 'debian':
subprocess.call(['/etc/init.d/networking', 'restart'])
elif distro == 'rhel':
subprocess.call(['/bin/systemctl', 'restart', 'network'])
pass
@staticmethod
def change_ip(interfaces, dns):
if distro == 'debian':
change_ip_ubuntu(interfaces, dns)
elif distro == 'rhel':
change_ip_rhel(interfaces, dns)
@staticmethod
def set_time(time):
Context._linux_set_time(float(time))
try:
subprocess.call(['/etc/init.d/ntp', 'restart'])
except:
pass
@staticmethod
def set_hostname(hostname):
if distro == 'debian':
with open('/etc/hostname', 'w') as f:
f.write(hostname)
elif distro == 'rhel':
for line in fileinput.input('/etc/sysconfig/network',
inplace=1):
if line.startswith('HOSTNAME='):
print 'HOSTNAME=%s' % hostname
else:
print line.rstrip()
with open('/etc/hosts', 'w') as f:
f.write("127.0.0.1 localhost\n"
"127.0.1.1 %s\n" % hostname)
subprocess.call(['/bin/hostname', hostname])
@staticmethod
def mount_store(host, username, password):
data = {'host': host, 'username': username, 'password': password}
data['dir'] = STORE_DIR
if not exists(STORE_DIR):
mkdir(STORE_DIR)
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
subprocess.call('mount -a', shell=True)
@staticmethod
def get_keys():
retval = []
try:
with open(AUTHORIZED_KEYS, 'r') as f:
for line in f.readlines():
try:
retval.append(PubKey.from_str(line))
except:
logger.exception(u'Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
@staticmethod
def add_keys(keys):
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
if p not in new_keys:
new_keys.append(p)
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def del_keys(keys):
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
try:
new_keys.remove(p)
except ValueError:
pass
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def cleanup():
filelist = ([
'/root/.bash_history'
'/home/cloud/.bash_history'
'/root/.ssh'
'/home/cloud/.ssh']
+ glob('/etc/ssh/ssh_host_*'))
for f in filelist:
rmtree(f, ignore_errors=True)
subprocess.call(('/usr/bin/ssh-keygen', '-A'))
@staticmethod
def start_access_server():
try:
subprocess.call(('/sbin/start', 'ssh'))
except OSError:
subprocess.call(('/bin/systemctl', 'start', 'sshd.service'))
@staticmethod
def append(data, filename, chunk_number, uuid):
if chunk_number == 0:
flag = "w"
else:
flag = "a"
with open(filename, flag) as myfile:
myfile.write(data)
@staticmethod
def update(filename, executable, checksum, uuid):
new_dir = working_directory + '.new'
old_dir = working_directory + '.old'
with open(filename, "r") as f:
data = f.read()
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
move(working_directory, old_dir)
move(new_dir, working_directory)
logger.info("Transfer completed!")
reactor.stop()
@staticmethod
def ipaddresses():
import netifaces
args = {}
interfaces = netifaces.interfaces()
for i in interfaces:
if i == 'lo':
continue
args[i] = []
addresses = netifaces.ifaddresses(i)
args[i] = ([x['addr']
for x in addresses.get(netifaces.AF_INET, [])] +
[x['addr']
for x in addresses.get(netifaces.AF_INET6, [])
if '%' not in x['addr']])
return args
@staticmethod
def get_agent_version():
try:
with open('version.txt') as f:
return f.readline()
except IOError:
return None
@staticmethod
def send_expiration(url):
import notify
notify.notify(url)
import netifaces
from netaddr import IPNetwork
import fileinput
import logging
logger = logging.getLogger()
interfaces_file = '/etc/network/interfaces'
ifcfg_template = '/etc/sysconfig/network-scripts/ifcfg-%s'
def get_interfaces_linux(interfaces):
for ifname in netifaces.interfaces():
mac = netifaces.ifaddresses(ifname)[17][0]['addr']
conf = interfaces.get(mac.upper())
if conf:
yield ifname, conf
def remove_interfaces_ubuntu(devices):
delete_device = False
for line in fileinput.input(interfaces_file, inplace=True):
line = line.rstrip()
words = line.split()
if line.startswith('#') or line == '' or line.isspace() or not words:
# keep line
print line
continue
if (words[0] in ('auto', 'allow-hotplug') and
words[1].split(':')[0] in devices):
# remove line
continue
if words[0] == 'iface':
if words[1].split(':')[0] in devices:
# remove line
delete_device = True
continue
else:
delete_device = False
if line[0] in (' ', '\t') and delete_device:
# remove line
continue
# keep line
print line
def change_ip_ubuntu(interfaces, dns):
data = list(get_interfaces_linux(interfaces))
remove_interfaces_ubuntu(dict(data).keys())
with open(interfaces_file, 'a') as f:
for ifname, conf in data:
ipv4_alias_counter = ipv6_alias_counter = 0
f.write('auto %s\n' % ifname)
for i in conf['addresses']:
ip_with_prefix = IPNetwork(i)
prefixlen = ip_with_prefix.prefixlen
ip = ip_with_prefix.ip
alias = ifname
if ip.version == 6:
if ipv6_alias_counter > 0:
alias = '%s:%d' % (ifname, ipv6_alias_counter)
ipv6_alias_counter += 1
else:
if ipv4_alias_counter > 0:
alias = '%s:%d' % (ifname, ipv4_alias_counter)
ipv4_alias_counter += 1
f.write(
'iface %(ifname)s %(proto)s static\n'
' address %(ip)s\n'
' netmask %(prefixlen)d\n'
' gateway %(gw)s\n'
' dns-nameservers %(dns)s\n' % {
'ifname': alias,
'proto': 'inet6' if ip.version == 6 else 'inet',
'ip': ip,
'prefixlen': prefixlen,
'gw': conf['gw6' if ip.version == 6 else 'gw4'],
'dns': dns})
# example:
# change_ip_ubuntu({
# u'02:00:00:02:A3:E8': {
# u'gw4': u'10.1.0.254', 'gw6': '2001::ffff',
# u'addresses': [u'10.1.0.84/24', '10.1.0.1/24', '2001::1/48']},
# u'02:00:00:02:A3:E9': {
# u'gw4': u'10.255.255.1', u'addresses': [u'10.255.255.9']}},
# '8.8.8.8')
def change_ip_rhel(interfaces, dns):
for ifname, conf in get_interfaces_linux(interfaces):
with open(ifcfg_template % ifname,
'w') as f:
f.write('DEVICE=%s\n'
'BOOTPROTO=none\n'
'USERCTL=no\n'
'ONBOOT=yes\n' % ifname)
for i in conf['addresses']:
ip_with_prefix = IPNetwork(i)
ip = ip_with_prefix.ip
if ip.version == 6:
f.write('IPV6INIT=yes\n'
'IPV6ADDR=%(ip)s/%(prefixlen)d\n'
'IPV6_DEFAULTGW=%(gw)s\n' % {
'ip': ip,
'prefixlen': ip_with_prefix.prefixlen,
'gw': conf['gw6']})
else:
f.write('NETMASK=%(netmask)s\n'
'IPADDR=%(ip)s\n'
'GATEWAY=%(gw)s\n' % {
'ip': ip,
'netmask': str(ip_with_prefix.netmask),
'gw': conf['gw4']})
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Virtio-Serial Port Protocol
"""
# system imports
import os
# dependent on pyserial ( http://pyserial.sf.net/ )
# only tested w/ 1.18 (5 Dec 2002)
# twisted imports
from twisted.internet import abstract, fdesc
class SerialPort(abstract.FileDescriptor):
"""
A select()able serial device, acting as a transport.
"""
connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor):
abstract.FileDescriptor.__init__(self, reactor)
self.port = deviceNameOrPortNumber
self._serial = os.open(
self.port, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
self.reactor = reactor
self.protocol = protocol
self.protocol.makeConnection(self)
self.startReading()
def fileno(self):
return self._serial
def writeSomeData(self, data):
"""
Write some data to the serial device.
"""
return fdesc.writeToFD(self.fileno(), data)
def doRead(self):
"""
Some data's readable from serial device.
"""
return fdesc.readFromFD(self.fileno(), self.protocol.dataReceived)
def connectionLost(self, reason):
"""
Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the
serial data.
"""
abstract.FileDescriptor.connectionLost(self, reason)
os.close(self._serial)
self.protocol.connectionLost(reason)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os import mkdir, environ, chdir
import platform
from shutil import copy
import subprocess
import sys
system = platform.system()
try:
chdir(sys.path[0])
subprocess.call(('pip', 'install', '-r', 'requirements.txt'))
if system == 'Linux':
copy("/root/agent/misc/vm_renewal", "/usr/local/bin/")
except:
pass # hope it works
from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall
import uptime
import logging
import fileinput
import tarfile
from os.path import expanduser, join, exists
from glob import glob
from inspect import getargspec, isfunction
from StringIO import StringIO
from base64 import decodestring
from hashlib import md5
from shutil import rmtree, move
from datetime import datetime
from utils import SerialLineReceiverBase
from ssh import PubKey
from network import change_ip_ubuntu, change_ip_rhel, change_ip_windows
logging.basicConfig()
logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
SSH_DIR = expanduser('~cloud/.ssh')
AUTHORIZED_KEYS = join(SSH_DIR, 'authorized_keys')
STORE_DIR = '/store'
mount_template_linux = (
'//%(host)s/%(username)s %(dir)s cifs username=%(username)s'
',password=%(password)s,iocharset=utf8,uid=cloud 0 0\n')
distros = {'Scientific Linux': 'rhel',
'CentOS': 'rhel',
'CentOS Linux': 'rhel',
'Debian': 'debian',
'Ubuntu': 'debian'}
if system == 'Linux':
distro = distros[platform.linux_distribution()[0]]
# http://stackoverflow.com/questions/12081310/
# python-module-to-change-system-date-and-time
def linux_set_time(time):
import ctypes
import ctypes.util
CLOCK_REALTIME = 0
class timespec(ctypes.Structure):
_fields_ = [("tv_sec", ctypes.c_long),
("tv_nsec", ctypes.c_long)]
librt = ctypes.CDLL(ctypes.util.find_library("rt"))
ts = timespec()
ts.tv_sec = int(time)
ts.tv_nsec = 0
librt.clock_settime(CLOCK_REALTIME, ctypes.byref(ts))
class Context(object):
@staticmethod
def change_password(password):
if system == 'Linux':
proc = subprocess.Popen(['/usr/sbin/chpasswd'],
stdin=subprocess.PIPE)
proc.communicate('cloud:%s\n' % password)
elif system == 'Windows':
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
ads_obj.SetPassword(password)
@staticmethod
def restart_networking():
if system == 'Linux':
if distro == 'debian':
subprocess.call(['/etc/init.d/networking', 'restart'])
elif distro == 'rhel':
subprocess.call(['/bin/systemctl', 'restart', 'network'])
pass
elif system == 'Windows':
pass
@staticmethod
def change_ip(interfaces, dns):
if system == 'Linux':
if distro == 'debian':
change_ip_ubuntu(interfaces, dns)
elif distro == 'rhel':
change_ip_rhel(interfaces, dns)
elif system == 'Windows':
change_ip_windows(interfaces, dns)
@staticmethod
def set_time(time):
if system == 'Linux':
linux_set_time(float(time))
try:
subprocess.call(['/etc/init.d/ntp', 'restart'])
except:
pass
elif system == 'Windows':
import win32api
t = datetime.utcfromtimestamp(float(time))
win32api.SetSystemTime(t.year, t.month, 0, t.day, t.hour,
t.minute, t.second, 0)
@staticmethod
def set_hostname(hostname):
if system == 'Linux':
if distro == 'debian':
with open('/etc/hostname', 'w') as f:
f.write(hostname)
elif distro == 'rhel':
for line in fileinput.input('/etc/sysconfig/network',
inplace=1):
if line.startswith('HOSTNAME='):
print 'HOSTNAME=%s' % hostname
else:
print line.rstrip()
with open('/etc/hosts', 'w') as f:
f.write("127.0.0.1 localhost\n"
"127.0.1.1 %s\n" % hostname)
subprocess.call(['/bin/hostname', hostname])
elif system == 'Windows':
import wmi
wmi.WMI().Win32_ComputerSystem()[0].Rename(hostname)
@staticmethod
def mount_store(host, username, password):
data = {'host': host, 'username': username, 'password': password}
if system == 'Linux':
data['dir'] = STORE_DIR
if not exists(STORE_DIR):
mkdir(STORE_DIR)
# TODO
for line in fileinput.input('/etc/fstab', inplace=True):
if not (line.startswith('//') and ' cifs ' in line):
print line.rstrip()
with open('/etc/fstab', 'a') as f:
f.write(mount_template_linux % data)
subprocess.call('mount -a', shell=True)
elif system == 'Windows':
import notify
url = 'cifs://%s:%s@%s/%s' % (username, password, host, username)
for c in notify.clients:
logger.debug("sending url %s to client %s", url, unicode(c))
c.sendLine(url.encode())
@staticmethod
def get_keys():
retval = []
try:
with open(AUTHORIZED_KEYS, 'r') as f:
for line in f.readlines():
try:
retval.append(PubKey.from_str(line))
except:
logger.exception(u'Invalid ssh key: ')
except IOError:
pass
return retval
@staticmethod
def _save_keys(keys):
print keys
try:
mkdir(SSH_DIR)
except OSError:
pass
with open(AUTHORIZED_KEYS, 'w') as f:
for key in keys:
f.write(unicode(key) + '\n')
@staticmethod
def add_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
if p not in new_keys:
new_keys.append(p)
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def del_keys(keys):
if system == 'Linux':
new_keys = Context.get_keys()
for key in keys:
try:
p = PubKey.from_str(key)
try:
new_keys.remove(p)
except ValueError:
pass
except:
logger.exception(u'Invalid ssh key: ')
Context._save_keys(new_keys)
@staticmethod
def cleanup():
if system == 'Linux':
filelist = ([
'/root/.bash_history'
'/home/cloud/.bash_history'
'/root/.ssh'
'/home/cloud/.ssh']
+ glob('/etc/ssh/ssh_host_*'))
for f in filelist:
rmtree(f, ignore_errors=True)
subprocess.call(('/usr/bin/ssh-keygen', '-A'))
elif system == 'Windows':
# TODO
pass
@staticmethod
def start_access_server():
if system == 'Linux':
try:
subprocess.call(('/sbin/start', 'ssh'))
except OSError:
subprocess.call(('/bin/systemctl', 'start', 'sshd.service'))
elif system == 'Windows':
# TODO
pass
@classmethod
def _update_linux(cls, data, uuid):
cur_dir = sys.path[0]
new_dir = cur_dir + '.new'
old_dir = cur_dir + '.old'
f = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=f, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
else:
rmtree(old_dir, ignore_errors=True)
move(cur_dir, old_dir)
move(new_dir, cur_dir)
logger.info('Updated')
reactor.stop()
@classmethod
def _update_windows(cls, data, executable, uuid):
# Extract the tar to the new path
cur_dir = sys.path[0]
new_dir = cur_dir + '.version'
f = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=f, mode='r|gz')
tar.extractall(new_dir)
except tarfile.ReadError as e:
logger.error(e)
else:
cls._update_registry(new_dir, executable)
logger.info('Updated')
reactor.stop()
@classmethod
def _update_registry(cls, dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from _winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, join(dir, executable))
return old_executable
@staticmethod
def append(data, filename, chunk_number, uuid):
if chunk_number == 0:
flag = "w"
else:
flag = "a"
with open(filename, flag) as myfile:
myfile.write(data)
@staticmethod
def append_end(filename, checksum, uuid):
with open(filename, "r") as f:
data = f.read()
local_checksum = md5(data).hexdigest()
if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall("/tmp")
except tarfile.ReadError as e:
logger.error(e)
logger.info("Transfer completed!")
@staticmethod
def update(data, executable, uuid):
if system == "Windows":
Context._update_windows(data, executable, uuid)
else:
Context._update_linux(data, executable, uuid)
@staticmethod
def ipaddresses():
import netifaces
args = {}
interfaces = netifaces.interfaces()
for i in interfaces:
if i == 'lo':
continue
args[i] = []
addresses = netifaces.ifaddresses(i)
args[i] = ([x['addr']
for x in addresses.get(netifaces.AF_INET, [])] +
[x['addr']
for x in addresses.get(netifaces.AF_INET6, [])
if '%' not in x['addr']])
return args
@staticmethod
def get_agent_version():
try:
with open('version.txt') as f:
return f.readline()
except IOError:
return None
@staticmethod
def send_expiration(url):
import notify
notify.notify(url)
class SerialLineReceiver(SerialLineReceiverBase):
def connectionMade(self):
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(),
'system': system})
def shutdown():
self.connectionLost2('shutdown')
d = defer.Deferred()
reactor.callLater(0.3, d.callback, "1")
return d
reactor.addSystemEventTrigger("before", "shutdown", shutdown)
def connectionLost2(self, reason):
self.send_command(command='agent_stopped',
args={})
def tick(self):
logger.debug("Sending tick")
try:
self.send_status()
except:
logger.exception("Twisted hide exception")
def __init__(self):
super(SerialLineReceiver, self).__init__()
self.lc = LoopingCall(self.tick)
self.lc.start(5, now=False)
def send_status(self):
import psutil
disk_usage = {(disk.device.replace('/', '_')):
psutil.disk_usage(disk.mountpoint).percent
for disk in psutil.disk_partitions()}
args = {"cpu": dict(psutil.cpu_times()._asdict()),
"ram": dict(psutil.virtual_memory()._asdict()),
"swap": dict(psutil.swap_memory()._asdict()),
"uptime": {"seconds": uptime.uptime()},
"disk": disk_usage,
"user": {"count": len(psutil.get_users())}}
self.send_response(response='status',
args=args)
def _check_args(self, func, args):
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
(self._pretty_fun(func), type(args).__name__))
# check for unexpected keyword arguments
argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs:
raise TypeError(
"Command %s got unexpected keyword arguments: %s" % (
self._pretty_fun(func), ", ".join(unexpected_kwargs)))
mandatory_args = argspec.args
if argspec.defaults: # remove those with default value
mandatory_args = mandatory_args[0:-len(argspec.defaults)]
missing_kwargs = set(mandatory_args) - set(args)
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
def _get_command(self, command, args):
if not isinstance(command, basestring) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
func = getattr(Context, command)
except AttributeError as e:
raise AttributeError(u'Command not found: %s (%s)' % (command, e))
if not isfunction(func):
raise AttributeError("Command refers to non-static method %s." %
self._pretty_fun(func))
self._check_args(func, args)
return func
@staticmethod
def _pretty_fun(fun):
try:
argspec = getargspec(fun)
args = argspec.args
if argspec.varargs:
args.append("*" + argspec.varargs)
if argspec.keywords:
args.append("**" + argspec.keywords)
return "%s(%s)" % (fun.__name__, ",".join(args))
except:
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
func = self._get_command(command, args)
retval = func(**args)
self.send_response(
response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)})
def handle_response(self, response, args):
pass
def _get_virtio_device():
path = None
GUID = '{6FDE7521-1B65-48ae-B628-80BE62016026}'
from infi.devicemanager import DeviceManager
dm = DeviceManager()
dm.root.rescan()
# Search Virtio-Serial by name TODO: search by class_guid
for i in dm.all_devices:
if i.has_property("description"):
if "virtio-serial".upper() in i.description.upper():
path = ("\\\\?\\" +
i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower()
)
return path
def main():
if system == 'Windows':
port = _get_virtio_device()
if port:
from w32serial import SerialPort
else:
from twisted.internet.serial import SerialPort
import pythoncom
pythoncom.CoInitialize()
port = r'\\.\COM1'
else:
#from twisted.internet.serial import SerialPort
# Try virtio first
from posixvirtio import SerialPort
port = "/dev/virtio-ports/agent"
if not exists(port):
port = '/dev/ttyS0'
logger.info("Opening port %s", port)
SerialPort(SerialLineReceiver(), port, reactor)
try:
from notify import register_publisher
register_publisher(reactor)
except:
logger.exception("Couldnt register notify publisher")
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor after run.")
if __name__ == '__main__':
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment