Commit 6a91b611 by Guba Sándor

vmdriver: static connection class and fixing pep issues.

parent a99cac51
""" Driver for libvirt. """
import libvirt import libvirt
import logging import logging
import os import os
...@@ -8,7 +9,6 @@ from vmcelery import celery, lib_connection ...@@ -8,7 +9,6 @@ from vmcelery import celery, lib_connection
sys.path.append(os.path.dirname(os.path.basename(__file__))) sys.path.append(os.path.dirname(os.path.basename(__file__)))
connection = None
state_dict = {0: 'NOSTATE', state_dict = {0: 'NOSTATE',
1: 'RUNNING', 1: 'RUNNING',
...@@ -21,16 +21,51 @@ state_dict = {0: 'NOSTATE', ...@@ -21,16 +21,51 @@ state_dict = {0: 'NOSTATE',
} }
#class Singleton(type):
#
# """ Singleton class."""
#
# _instances = {}
#
# def __call__(cls, *args, **kwargs):
# if cls not in cls._instances:
# cls._instances[cls] = super(Singleton, cls).__call__(*args,
# **kwargs)
# return cls._instances[cls]
class Connection(object):
""" Singleton class to handle connection."""
# __metaclass__ = Singleton
connection = None
@classmethod
def get(cls):
""" Return the libvirt connection."""
return cls.connection
@classmethod
def set(cls, connection):
""" Set the libvirt connection."""
cls.connection = connection
@decorator @decorator
def req_connection(original_function, *args, **kw): def req_connection(original_function, *args, **kw):
'''Connection checking decorator for libvirt. """Connection checking decorator for libvirt.
If envrionment variable LIBVIRT_KEEPALIVE is set If envrionment variable LIBVIRT_KEEPALIVE is set
it will use the connection from the celery worker. it will use the connection from the celery worker.
'''
Return the decorateed function
"""
logging.debug("Decorator running") logging.debug("Decorator running")
global connection if Connection.get() is None:
if connection is None:
connect() connect()
try: try:
logging.debug("Decorator calling original function") logging.debug("Decorator calling original function")
...@@ -48,6 +83,11 @@ def req_connection(original_function, *args, **kw): ...@@ -48,6 +83,11 @@ def req_connection(original_function, *args, **kw):
@decorator @decorator
def wrap_libvirtError(original_function, *args, **kw): def wrap_libvirtError(original_function, *args, **kw):
""" Decorator to wrap libvirt error in simple Exception.
Return decorated function
"""
try: try:
return original_function(*args, **kw) return original_function(*args, **kw)
except libvirt.libvirtError as e: except libvirt.libvirtError as e:
...@@ -60,34 +100,34 @@ def wrap_libvirtError(original_function, *args, **kw): ...@@ -60,34 +100,34 @@ def wrap_libvirtError(original_function, *args, **kw):
@celery.task @celery.task
@wrap_libvirtError @wrap_libvirtError
def connect(connection_string='qemu:///system'): def connect(connection_string='qemu:///system'):
'''Connect to the libvirt daemon specified in the """ Connect to the libvirt daemon.
connection_string or the local root.
''' String is specified in the connection_string parameter
global connection the default is the local root.
"""
if os.getenv('LIBVIRT_KEEPALIVE') is None: if os.getenv('LIBVIRT_KEEPALIVE') is None:
if connection is None: if Connection.get() is None:
connection = libvirt.open(connection_string) Connection.set(libvirt.open(connection_string))
logging.debug("Connection estabilished to libvirt.") logging.debug("Connection estabilished to libvirt.")
else: else:
logging.debug("There is already an active connection to libvirt.") logging.debug("There is already an active connection to libvirt.")
else: else:
connection = lib_connection Connection.set(lib_connection)
logging.debug("Using celery libvirt connection connection.") logging.debug("Using celery libvirt connection connection.")
@celery.task @celery.task
@wrap_libvirtError @wrap_libvirtError
def disconnect(): def disconnect():
'''Disconnect from the active libvirt daemon connection. """ Disconnect from the active libvirt daemon connection."""
'''
global connection
if os.getenv('LIBVIRT_KEEPALIVE') is None: if os.getenv('LIBVIRT_KEEPALIVE') is None:
if connection is None: if Connection.get() is None:
logging.debug('There is no available libvirt conection.') logging.debug('There is no available libvirt conection.')
else: else:
connection.close() Connection.get().close()
logging.debug('Connection closed to libvirt.') logging.debug('Connection closed to libvirt.')
connection = None Connection.set(None)
else: else:
logging.debug('Keepalive connection should not close.') logging.debug('Keepalive connection should not close.')
...@@ -96,9 +136,8 @@ def disconnect(): ...@@ -96,9 +136,8 @@ def disconnect():
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def define(vm): def define(vm):
'''Define permanent virtual machine from xml """ Define permanent virtual machine from xml. """
''' Connection.get().defineXML(vm.dump_xml())
connection.defineXML(vm.dump_xml())
logging.info("Virtual machine %s is defined from xml", vm.name) logging.info("Virtual machine %s is defined from xml", vm.name)
...@@ -106,14 +145,17 @@ def define(vm): ...@@ -106,14 +145,17 @@ def define(vm):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def create(vm_desc): def create(vm_desc):
'''Create and start non-permanent virtual machine from xml """ Create and start non-permanent virtual machine from xml.
Return the domain info dict.
flags can be: flags can be:
VIR_DOMAIN_NONE = 0 VIR_DOMAIN_NONE = 0
VIR_DOMAIN_START_PAUSED = 1 VIR_DOMAIN_START_PAUSED = 1
VIR_DOMAIN_START_AUTODESTROY = 2 VIR_DOMAIN_START_AUTODESTROY = 2
VIR_DOMAIN_START_BYPASS_CACHE = 4 VIR_DOMAIN_START_BYPASS_CACHE = 4
VIR_DOMAIN_START_FORCE_BOOT = 8 VIR_DOMAIN_START_FORCE_BOOT = 8
'''
"""
vm = VMInstance.deserialize(vm_desc) vm = VMInstance.deserialize(vm_desc)
# Setting proper hypervisor # Setting proper hypervisor
vm.vm_type = os.getenv("HYPERVISOR_TYPE", "test") vm.vm_type = os.getenv("HYPERVISOR_TYPE", "test")
...@@ -121,13 +163,13 @@ def create(vm_desc): ...@@ -121,13 +163,13 @@ def create(vm_desc):
logging.info(xml) logging.info(xml)
# Emulating DOMAIN_START_PAUSED FLAG behaviour on test driver # Emulating DOMAIN_START_PAUSED FLAG behaviour on test driver
if vm.vm_type == "test": if vm.vm_type == "test":
connection.createXML( Connection.get().createXML(
xml, libvirt.VIR_DOMAIN_NONE) xml, libvirt.VIR_DOMAIN_NONE)
domain = lookupByName(vm.name) domain = lookupByName(vm.name)
domain.suspend() domain.suspend()
# Real driver create # Real driver create
else: else:
connection.createXML( Connection.get().createXML(
vm.dump_xml(), libvirt.VIR_DOMAIN_START_PAUSED) vm.dump_xml(), libvirt.VIR_DOMAIN_START_PAUSED)
logging.info("Virtual machine %s is created from xml", vm.name) logging.info("Virtual machine %s is created from xml", vm.name)
return domain_info(vm.name) return domain_info(vm.name)
...@@ -137,19 +179,17 @@ def create(vm_desc): ...@@ -137,19 +179,17 @@ def create(vm_desc):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def shutdown(name): def shutdown(name):
'''Shutdown virtual machine (need ACPI support). """ Shutdown virtual machine (need ACPI support). """
'''
domain = lookupByName(name) domain = lookupByName(name)
domain.shutdown() domain.shutdown()
return _parse_info(domain.info())
@celery.task @celery.task
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def delete(name): def delete(name):
'''Destroy the running called 'name' virtual machine. """ Destroy the running called 'name' virtual machine. """
'''
domain = lookupByName(name) domain = lookupByName(name)
domain.destroy() domain.destroy()
...@@ -158,12 +198,14 @@ def delete(name): ...@@ -158,12 +198,14 @@ def delete(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def list_domains(): def list_domains():
''' """ List the running domains.
:return list: List of domains name in host
''' :return list: List of domains name in host.
"""
domain_list = [] domain_list = []
for i in connection.listDomainsID(): for i in Connection.get().listDomainsID():
dom = connection.lookupByID(i) dom = Connection.get().lookupByID(i)
domain_list.append(dom.name()) domain_list.append(dom.name())
return domain_list return domain_list
...@@ -172,18 +214,19 @@ def list_domains(): ...@@ -172,18 +214,19 @@ def list_domains():
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def lookupByName(name): def lookupByName(name):
'''Return with the requested Domain """ Return with the requested Domain. """
''' return Connection.get().lookupByName(name)
return connection.lookupByName(name)
@celery.task @celery.task
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def undefine(name): def undefine(name):
'''Undefine an already defined virtual machine. """ Undefine an already defined virtual machine.
If it's running it becomes transient (lsot on reboot)
''' If it's running it becomes transient (lost on reboot)
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.undefine() domain.undefine()
...@@ -192,8 +235,8 @@ def undefine(name): ...@@ -192,8 +235,8 @@ def undefine(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def start(name): def start(name):
'''Start an already defined virtual machine. """ Start an already defined virtual machine."""
'''
domain = lookupByName(name) domain = lookupByName(name)
domain.create() domain.create()
...@@ -202,8 +245,12 @@ def start(name): ...@@ -202,8 +245,12 @@ def start(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def suspend(name): def suspend(name):
'''Stop virtual machine and keep memory in RAM. """ Stop virtual machine and keep memory in RAM.
'''
Return the domain info dict.
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.suspend() domain.suspend()
return _parse_info(domain.info()) return _parse_info(domain.info())
...@@ -213,8 +260,8 @@ def suspend(name): ...@@ -213,8 +260,8 @@ def suspend(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def save(name, path): def save(name, path):
'''Stop virtual machine and save its memory to path. """ Stop virtual machine and save its memory to path. """
'''
domain = lookupByName(name) domain = lookupByName(name)
domain.save(path) domain.save(path)
...@@ -223,9 +270,14 @@ def save(name, path): ...@@ -223,9 +270,14 @@ def save(name, path):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def restore(name, path): def restore(name, path):
'''Restore a saved virtual machine """ Restore a saved virtual machine.
from the memory image stored at path.'''
connection.restore(path) Restores the virtual machine from the memory image
stored at path.
Return the domain info dict.
"""
Connection.get().restore(path)
return domain_info(name) return domain_info(name)
...@@ -233,8 +285,12 @@ def restore(name, path): ...@@ -233,8 +285,12 @@ def restore(name, path):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def resume(name): def resume(name):
'''Resume stopped virtual machines. """ Resume stopped virtual machines.
'''
Return the domain info dict.
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.resume() domain.resume()
return _parse_info(domain.info()) return _parse_info(domain.info())
...@@ -244,8 +300,12 @@ def resume(name): ...@@ -244,8 +300,12 @@ def resume(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def reset(name): def reset(name):
'''Reset (power reset) virtual machine. """ Reset (power reset) virtual machine.
'''
Return the domain info dict.
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.reset() domain.reset()
return _parse_info(domain.info()) return _parse_info(domain.info())
...@@ -255,8 +315,11 @@ def reset(name): ...@@ -255,8 +315,11 @@ def reset(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def reboot(name): def reboot(name):
'''Reboot (with guest acpi support) virtual machine. """ Reboot (with guest acpi support) virtual machine.
'''
Return the domain info dict.
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.reboot() domain.reboot()
return _parse_info(domain.info()) return _parse_info(domain.info())
...@@ -266,7 +329,10 @@ def reboot(name): ...@@ -266,7 +329,10 @@ def reboot(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def node_info(): def node_info():
''' Get info from Host as dict: """ Get info from Host as dict.
Return dict:
model string indicating the CPU model model string indicating the CPU model
memory memory size in kilobytes memory memory size in kilobytes
cpus the number of active CPUs cpus the number of active CPUs
...@@ -279,15 +345,22 @@ def node_info(): ...@@ -279,15 +345,22 @@ def node_info():
cores number of cores per socket, total number of cores number of cores per socket, total number of
processors in case of unusual NUMA topolog processors in case of unusual NUMA topolog
threads number of threads per core, 1 in case of unusual numa topology threads number of threads per core, 1 in case of unusual numa topology
'''
"""
keys = ['model', 'memory', 'cpus', 'mhz', keys = ['model', 'memory', 'cpus', 'mhz',
'nodes', 'sockets', 'cores', 'threads'] 'nodes', 'sockets', 'cores', 'threads']
values = connection.getInfo() values = Connection.get().getInfo()
return dict(zip(keys, values)) return dict(zip(keys, values))
def _parse_info(values): def _parse_info(values):
'''Parse libvirt domain info into dict''' """ Parse libvirt domain info into dict.
Return the info dict.
"""
keys = ['state', 'maxmem', 'memory', 'virtcpunum', 'cputime'] keys = ['state', 'maxmem', 'memory', 'virtcpunum', 'cputime']
info = dict(zip(keys, values)) info = dict(zip(keys, values))
# Change state to proper ENUM # Change state to proper ENUM
...@@ -299,13 +372,16 @@ def _parse_info(values): ...@@ -299,13 +372,16 @@ def _parse_info(values):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def domain_info(name): def domain_info(name):
''' """ Get the domain info from libvirt.
Return the domain info dict:
state the running state, one of virDomainState state the running state, one of virDomainState
maxmem the maximum memory in KBytes allowed maxmem the maximum memory in KBytes allowed
memory the memory in KBytes used by the domain memory the memory in KBytes used by the domain
virtcpunum the number of virtual CPUs for the domain virtcpunum the number of virtual CPUs for the domain
cputime the CPU time used in nanoseconds cputime the CPU time used in nanoseconds
'''
"""
dom = lookupByName(name) dom = lookupByName(name)
return _parse_info(dom.info()) return _parse_info(dom.info())
...@@ -314,7 +390,8 @@ def domain_info(name): ...@@ -314,7 +390,8 @@ def domain_info(name):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def network_info(name, network): def network_info(name, network):
''' """ Return the network info dict.
rx_bytes rx_bytes
rx_packets rx_packets
rx_errs rx_errs
...@@ -323,7 +400,8 @@ def network_info(name, network): ...@@ -323,7 +400,8 @@ def network_info(name, network):
tx_packets tx_packets
tx_errs tx_errs
tx_drop tx_drop
'''
"""
keys = ['rx_bytes', 'rx_packets', 'rx_errs', 'rx_drop', keys = ['rx_bytes', 'rx_packets', 'rx_errs', 'rx_drop',
'tx_bytes', 'tx_packets', 'tx_errs', 'tx_drop'] 'tx_bytes', 'tx_packets', 'tx_errs', 'tx_drop']
dom = lookupByName(name) dom = lookupByName(name)
...@@ -336,10 +414,12 @@ def network_info(name, network): ...@@ -336,10 +414,12 @@ def network_info(name, network):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def send_key(name, key_code): def send_key(name, key_code):
''' Sending linux key_code to the name vm """ Sending linux key_code to the name vm.
key_code can be optained from linux_keys.py key_code can be optained from linux_keys.py
e.x: linuxkeys.KEY_RIGHTCTRL e.x: linuxkeys.KEY_RIGHTCTRL
'''
"""
domain = lookupByName(name) domain = lookupByName(name)
domain.sendKey(libvirt.VIR_KEYCODE_SET_LINUX, 100, [key_code], 1, 0) domain.sendKey(libvirt.VIR_KEYCODE_SET_LINUX, 100, [key_code], 1, 0)
...@@ -353,19 +433,20 @@ def _stream_handler(stream, buf, opaque): ...@@ -353,19 +433,20 @@ def _stream_handler(stream, buf, opaque):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def screenshot(name, path): def screenshot(name, path):
"""Save screenshot of virtual machine """Save screenshot of virtual machine.
to the path as name-screenshot.ppm
Image is saved to the path as name-screenshot.ppm
""" """
# Import linuxkeys to get defines # Import linuxkeys to get defines
import linuxkeys import linuxkeys
# Connection need for the stream object # Connection need for the stream object
global connection
domain = lookupByName(name) domain = lookupByName(name)
# Send key to wake up console # Send key to wake up console
domain.sendKey(libvirt.VIR_KEYCODE_SET_LINUX, domain.sendKey(libvirt.VIR_KEYCODE_SET_LINUX,
100, [linuxkeys.KEY_RIGHTCTRL], 1, 0) 100, [linuxkeys.KEY_RIGHTCTRL], 1, 0)
# Create Stream to get data # Create Stream to get data
stream = connection.newStream(0) stream = Connection.get().newStream(0)
# Take screenshot accessible by stream (return mimetype) # Take screenshot accessible by stream (return mimetype)
domain.screenshot(stream, 0, 0) domain.screenshot(stream, 0, 0)
# Get file to save data (TODO: send on AMQP?) # Get file to save data (TODO: send on AMQP?)
...@@ -383,7 +464,7 @@ def screenshot(name, path): ...@@ -383,7 +464,7 @@ def screenshot(name, path):
@req_connection @req_connection
@wrap_libvirtError @wrap_libvirtError
def migrate(name, host, live=False): def migrate(name, host, live=False):
'''Migrate domain to host''' """ Migrate domain to host. """
flags = libvirt.VIR_MIGRATE_PEER2PEER flags = libvirt.VIR_MIGRATE_PEER2PEER
if live: if live:
flags = flags | libvirt.VIR_MIGRATE_LIVE flags = flags | libvirt.VIR_MIGRATE_LIVE
...@@ -393,4 +474,4 @@ def migrate(name, host, live=False): ...@@ -393,4 +474,4 @@ def migrate(name, host, live=False):
flags=flags, flags=flags,
dname=name, dname=name,
bandwidth=0) bandwidth=0)
#return _parse_info(domain.info()) # return _parse_info(domain.info())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment