Commit 67869460 by Szeberényi Imre

service fix, update fix

parent 1f863982
...@@ -2,67 +2,75 @@ import logging ...@@ -2,67 +2,75 @@ import logging
from logging.handlers import NTEventLogHandler from logging.handlers import NTEventLogHandler
from time import sleep from time import sleep
import os import os
from os.path import join
import servicemanager import servicemanager
import socket import socket
import sys import sys
import winerror
import win32event import win32event
import win32service import win32service
import win32serviceutil import win32serviceutil
logger = logging.getLogger() from utils import setup_logging
fh = NTEventLogHandler( from windows.winutils import getRegistryVal, get_windows_version, servicePostUpdate
"CIRCLE Watchdog", dllname=os.path.dirname(__file__))
if getattr(sys, "frozen", False):
logger = setup_logging(logfile=r"C:\Circle\watchdog.log")
else:
logger = setup_logging()
fh = NTEventLogHandler("CIRCLE Watchdog")
formatter = logging.Formatter( formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s") "%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO') level = getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-Agent\\Parameters",
"LogLevel",
"INFO"
)
logger.setLevel(level) logger.setLevel(level)
logger.info("%s loaded", __file__) logger.info("%s loaded", __file__)
service_name = "circle-agent"
stopped = False
def watch():
def check_service(service_name):
return win32serviceutil.QueryServiceStatus(service_name)[1] == 4
def start_service():
win32serviceutil.StartService(service_name)
timo_base = 20
timo = timo_base
sleep(6*timo) # boot process may have triggered the agent, so we are patient
while True:
if not check_service(service_name):
logger.info("Service %s is not running.", service_name)
try:
start_service()
timo = timo_base
logger.info("Service %s started.", service_name)
except Exception as e:
timo *= 2
logger.exception("Cant start service %s new timo: %s" % (service_name, timo))
if stopped:
return
sleep(timo)
class AppServerSvc (win32serviceutil.ServiceFramework): class AppServerSvc (win32serviceutil.ServiceFramework):
_svc_name_ = "circle-watchdog" _svc_name_ = "circle-watchdog"
_svc_display_name_ = "CIRCLE Watchdog" _svc_display_name_ = "CIRCLE Watchdog"
_svc_description_ = "Watchdog for CIRCLE Agent"
def __init__(self, args): def __init__(self, args):
win32serviceutil.ServiceFramework.__init__(self, args) win32serviceutil.ServiceFramework.__init__(self, args)
self.hWaitStop = win32event.CreateEvent(None, 0, 0, None) self.hWaitStop = win32event.CreateEvent(None, 0, 0, None)
socket.setdefaulttimeout(60) socket.setdefaulttimeout(60)
self._stopped = False
def watch(self, checked_service):
logger.debug("watch...")
def check_service(checked_service):
return win32serviceutil.QueryServiceStatus(checked_service)[1] == 4
def start_service():
win32serviceutil.StartService(checked_service)
timo_base = 20
timo = timo_base
sleep(6*timo) # boot process may have triggered the agent, so we are patient
while not self._stopped:
logger.debug("checking....(timo: %d", timo)
if not check_service(checked_service):
logger.info("Service %s is not running.", checked_service)
try:
start_service()
timo = timo_base
logger.info("Service %s restarted.", checked_service)
except Exception:
timo = min(timo * 2, 15 * 60) # max 15 perc
logger.exception("Cant start service %s new timo: %s" % (checked_service, timo))
sleep(timo)
def SvcStop(self): def SvcStop(self):
self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING)
win32event.SetEvent(self.hWaitStop) win32event.SetEvent(self.hWaitStop)
global stopped self._stopped = True
stopped = True
logger.info("%s stopped", __file__) logger.info("%s stopped", __file__)
def SvcDoRun(self): def SvcDoRun(self):
...@@ -70,11 +78,27 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -70,11 +78,27 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED, servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, '')) (self._svc_name_, ''))
logger.info("%s starting", __file__) logger.info("%s starting", __file__)
watch() working_dir = r"C:\circle"
exe = "circle-watchdog.exe"
exe_path = join(working_dir, exe)
logger.debug("hahooo %s %s", self._svc_name_, exe_path)
if servicePostUpdate(self._svc_name_, exe_path):
# Service updated, Restart needed
logger.debug("update....")
self.ReportServiceStatus(
win32service.SERVICE_STOPPED,
win32ExitCode=winerror.ERROR_SERVICE_SPECIFIC_ERROR, # 1066
svcExitCode=int(1)
)
return
self.watch("circle-agent")
# normal stop
self.ReportServiceStatus(win32service.SERVICE_STOPPED)
def main(): def main():
logger.info("Started: %s", sys.argv) logger.info("Started: %s", sys.argv)
if len(sys.argv) == 1: if len(sys.argv) == 1:
# service must be starting... # service must be starting...
# for the sake of debugging etc, we use win32traceutil to see # for the sake of debugging etc, we use win32traceutil to see
......
import logging
from logging.handlers import NTEventLogHandler
import os import os
import servicemanager import servicemanager
import socket import socket
import sys import sys
import winerror
import win32event import win32event
import win32service import win32service
import win32serviceutil import win32serviceutil
import logging
from logging.handlers import NTEventLogHandler
#import agent #import agent, reactor
from agent import main as agent_main, reactor from agent import main as agent_main
from twisted.internet import reactor
logger = logging.getLogger() logger = logging.getLogger()
fh = NTEventLogHandler( fh = NTEventLogHandler("CIRCLE Agent")
"CIRCLE Agent", dllname=os.path.dirname(__file__))
formatter = logging.Formatter( formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s") "%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
#level = os.environ.get('LOGLEVEL', 'DEBUG') #logger.propagate = False
#logger.setLevel('INFO')
logger.setLevel('INFO')
logger.info("%s loaded", __file__) logger.info("%s loaded", __file__)
class AppServerSvc (win32serviceutil.ServiceFramework): class AppServerSvc (win32serviceutil.ServiceFramework):
_svc_name_ = "circle-agent" _svc_name_ = "circle-agent"
_svc_display_name_ = "CIRCLE Agent" _svc_display_name_ = "CIRCLE Agent"
_svc_desciption_ = "CIRCLE cloud contextualization agent" _svc_description_ = "CIRCLE cloud contextualization agent"
def __init__(self, args): def __init__(self, args):
win32serviceutil.ServiceFramework.__init__(self, args) win32serviceutil.ServiceFramework.__init__(self, args)
...@@ -37,7 +38,12 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -37,7 +38,12 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def SvcStop(self): def SvcStop(self):
self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING)
win32event.SetEvent(self.hWaitStop) win32event.SetEvent(self.hWaitStop)
reactor.stop() try:
reactor.stop()
reactor.callFromThread(reactor.stop)
# reactor.callLater(0, reactor.stop)
except Exception:
logger.exception("reactor.stop failed")
logger.info("%s stopped", __file__) logger.info("%s stopped", __file__)
def SvcDoRun(self): def SvcDoRun(self):
...@@ -49,12 +55,25 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -49,12 +55,25 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED, servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, '')) (self._svc_name_, ''))
break break
except Exception as e: except Exception:
logger.exception("Servicemanager busy?", e) logger.exception("Servicemanager busy?")
cnt -= 1 cnt -= 1
if cnt: if cnt:
logger.info("Starting agent_main") logger.info("Starting agent_main")
agent_main() ret = agent_main()
logger.error("agent_main returned ret=%r type=%s", ret, type(ret).__name__)
logger.info("agent_main finished")
if ret != 0:
# “Service-specific error”
self.ReportServiceStatus(
win32service.SERVICE_STOPPED,
win32ExitCode=winerror.ERROR_SERVICE_SPECIFIC_ERROR, # 1066
svcExitCode=int(ret)
)
return
# normal stop
self.ReportServiceStatus(win32service.SERVICE_STOPPED)
def main(): def main():
......
...@@ -8,16 +8,21 @@ import subprocess ...@@ -8,16 +8,21 @@ import subprocess
import sys import sys
from shutil import rmtree from shutil import rmtree
logging.basicConfig( from logging.handlers import TimedRotatingFileHandler
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s", from pathlib import Path
datefmt="%d/%b/%Y %H:%M:%S", from utils import setup_logging
)
logger = logging.getLogger() if getattr(sys, "frozen", False):
logger = setup_logging(logfile=r"C:\Circle\agent.log")
else:
logger = setup_logging()
level = environ.get('LOGLEVEL', 'INFO') level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level) logger.setLevel(level)
system = platform.system() # noqa system = platform.system() # noqa
logger.debug("system:%s", system) logger.debug("system:%s", system)
win = platform.system() == "Windows"
if len(sys.argv) != 1 and (system == "Linux" or system == "FreeBSD"): # noqa if len(sys.argv) != 1 and (system == "Linux" or system == "FreeBSD"): # noqa
logger.info("Installing agent on %s system", system) logger.info("Installing agent on %s system", system)
...@@ -39,9 +44,17 @@ import uptime ...@@ -39,9 +44,17 @@ import uptime
from inspect import getargs, isfunction from inspect import getargs, isfunction
from utils import SerialLineReceiverBase from utils import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext if win:
# (relative import error. from windows.winutils import getRegistryVal, get_windows_version
from context import BaseContext, get_context, get_serial # noqa level = getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-agent\\Parameters",
"LogLevel",
level
)
logger.setLevel(level)
# system = get_windows_version()
from context import get_context, get_serial # noqa
try: try:
# Python 2: "unicode" is built-in # Python 2: "unicode" is built-in
...@@ -86,14 +99,14 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -86,14 +99,14 @@ class SerialLineReceiver(SerialLineReceiverBase):
self.send_command(command='agent_stopped', args={}) self.send_command(command='agent_stopped', args={})
def mayStartNow(self): def mayStartNow(self):
if BaseContext.placed: if Context.placed:
self.mayStartNowId.stop() self.mayStartNowId.stop()
logger.info("Placed") logger.info("Placed")
return return
self.send_startMsg() self.send_startMsg()
def tick(self): def tick(self):
logger.debug("Sending tick") # logger.debug("Sending tick")
try: try:
self.send_status() self.send_status()
except Exception as e: except Exception as e:
...@@ -109,7 +122,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -109,7 +122,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
return d return d
def send_startMsg(self): def send_startMsg(self):
logger.debug("Sending start message...") logger.debug("Sending start message: %s %s", Context.get_agent_version(), system)
# Hack for flushing the lower level buffersr # Hack for flushing the lower level buffersr
self.transport.dataBuffer = b"" self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
...@@ -132,10 +145,10 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -132,10 +145,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
"disk": disk_usage, "disk": disk_usage,
"user": {"count": len(psutil.users())}} "user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args) self.send_response(response='status', args=args)
logger.debug("send_status finished") # logger.debug("send_status finished")
def _check_args(self, func, args): def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args)) # logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict): if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a " raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." % "dict for command %s instead of %s." %
...@@ -161,10 +174,10 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -161,10 +174,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs: if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % ( raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs))) self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished") # logger.debug("_check_args finished")
def _get_command(self, command, args): def _get_command(self, command, args):
logger.debug("_get_command %s %s" % (command, args)) # logger.debug("_get_command %s" % command)
if not isinstance(command, unicode) or command.startswith('_'): if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command) raise AttributeError(u'Invalid command: %s' % command)
try: try:
...@@ -177,7 +190,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -177,7 +190,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
self._pretty_fun(func)) self._pretty_fun(func))
self._check_args(func, args) self._check_args(func, args)
logger.debug("_get_command finished") # logger.debug("_get_command finished")
return func return func
@staticmethod @staticmethod
...@@ -194,9 +207,9 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -194,9 +207,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
return "<%s>" % type(fun).__name__ return "<%s>" % type(fun).__name__
def handle_command(self, command, args): def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args)) logger.debug("handle_command %s" % command)
func = self._get_command(command, args) func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args)) # logger.debug("Call cmd: %s" % func)
retval = func(**args) retval = func(**args)
logger.debug("Retval: %s" % retval) logger.debug("Retval: %s" % retval)
self.send_response( self.send_response(
...@@ -207,6 +220,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -207,6 +220,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
pass pass
def main(): def main():
if Context.postUpdate():
# Service updated, Restart needed
return 1
# Get proper serial class and port name # Get proper serial class and port name
(serial, port) = get_serial() (serial, port) = get_serial()
logger.info("Opening port %s", port) logger.info("Opening port %s", port)
...@@ -220,7 +238,7 @@ def main(): ...@@ -220,7 +238,7 @@ def main():
logger.debug("Starting reactor.") logger.debug("Starting reactor.")
reactor.run() reactor.run()
logger.debug("Reactor finished.") logger.debug("Reactor finished.")
return Context.exit_code
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -3,15 +3,20 @@ ...@@ -3,15 +3,20 @@
# Should be in autostart and run by the user logged in # Should be in autostart and run by the user logged in
import logging import logging
from os.path import join
from os import environ
from notify import run_client, get_temp_dir
logger = logging.getLogger() logger = logging.getLogger()
fh = logging.FileHandler("agent-client.log") logfile = join(get_temp_dir(), "agent-client.log")
fh = logging.FileHandler(logfile)
formatter = logging.Formatter( formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s") "%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
from notify import run_client
if __name__ == '__main__': if __name__ == '__main__':
run_client() run_client()
...@@ -5,7 +5,21 @@ import logging ...@@ -5,7 +5,21 @@ import logging
import platform import platform
import re import re
logger = logging.getLogger() logger = logging.getLogger(__name__)
# --- compatibility patch for old libs expecting inspect.getargspec (Py<3.11) ---
import inspect
from collections import namedtuple
if not hasattr(inspect, "getargspec"):
ArgSpec = namedtuple("ArgSpec", "args varargs keywords defaults")
def getargspec(func):
fs = inspect.getfullargspec(func)
return ArgSpec(fs.args, fs.varargs, fs.varkw, fs.defaults)
inspect.getargspec = getargspec
# ---------------------------------------------------------------------------
def _get_virtio_device(): def _get_virtio_device():
path = None path = None
...@@ -78,7 +92,12 @@ def get_serial(): ...@@ -78,7 +92,12 @@ def get_serial():
class BaseContext(object): class BaseContext(object):
placed = False # if we reciwed password or net commands placed = False # if we reciwed password or net commands
exit_code = 0
@staticmethod
def postUpdate():
return false
@staticmethod @staticmethod
def change_password(password): def change_password(password):
pass pass
......
...@@ -34,7 +34,7 @@ try: ...@@ -34,7 +34,7 @@ try:
except NameError: except NameError:
unicode = str unicode = str
logger = logging.getLogger() logger = logging.getLogger(__name__)
logger.debug("notify imported") logger.debug("notify imported")
file_name = "vm_renewal.json" file_name = "vm_renewal.json"
win = platform.system() == "Windows" win = platform.system() == "Windows"
...@@ -51,6 +51,8 @@ def parse_arguments(): ...@@ -51,6 +51,8 @@ def parse_arguments():
def get_temp_dir(): def get_temp_dir():
if os.getenv("TMPDIR"): if os.getenv("TMPDIR"):
temp_dir = os.getenv("TMPDIR") temp_dir = os.getenv("TMPDIR")
elif os.getenv("TEMP"):
temp_dir = os.getenv("TEMP")
elif os.getenv("TMP"): elif os.getenv("TMP"):
temp_dir = os.getenv("TMP") temp_dir = os.getenv("TMP")
elif os.path.exists("/tmp"): elif os.path.exists("/tmp"):
...@@ -72,20 +74,25 @@ def wall(text): ...@@ -72,20 +74,25 @@ def wall(text):
process.communicate(input=text)[0] process.communicate(input=text)[0]
def accept(): def accept(url=None):
import datetime import datetime
from tzlocal import get_localzone from tzlocal import get_localzone
from pytz import UTC from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name) if url == None:
if not os.path.isfile(file_path): file_path = os.path.join(get_temp_dir(), file_name)
print("There is no recent notification to accept.") if not os.path.isfile(file_path):
return False print("There is no recent notification to accept.")
return False
# Load the saved url
url = json.load(open(file_path, "r")) # Load the saved url
url = json.load(open(file_path, "r"))
os.remove(file_path)
cj = cookielib.CookieJar() cj = cookielib.CookieJar()
opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj)) opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj))
msh = None
ret = False
new_local_time = None
try: try:
opener.open(url) # GET to collect cookies opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url)) cookies = cj._cookies_for_request(urllib2.Request(url))
...@@ -95,25 +102,23 @@ def accept(): ...@@ -95,25 +102,23 @@ def accept():
b"x-csrftoken": token}) b"x-csrftoken": token})
rsp = opener.open(req) rsp = opener.open(req)
data = json.load(rsp) data = json.load(rsp)
logger.debug("data %r", data)
newtime = data["new_suspend_time"] newtime = data["new_suspend_time"]
# Parse time from JSON (Create UTC Localized Datetime objec) msg = data["message"]
parsed_time = datetime.datetime.strptime( # # Parse time from JSON (Create UTC Localized Datetime objec)
newtime[:-6], "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=UTC) # parsed_time = datetime.datetime.strptime(
# Convert to the machine localization # newtime[:-6], "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=UTC)
new_local_time = parsed_time.astimezone( # # Convert to the machine localization
get_localzone()).strftime("%Y-%m-%d %H:%M:%S") # new_local_time = parsed_time.astimezone(
# get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e: except ValueError as e:
print("Parsing time failed: %s" % e) msg = "Parsing time failed: %s" % e
except Exception as e: except Exception as e:
print(e) msg = "Renewal failed. Please try it manually at %s" % url
print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed") logger.exception("renew failed")
return False
else: else:
print("Renew succeeded. The machine will be " ret = True
"suspended at %s." % new_local_time) return { 'ret': ret, 'msg': msg, 'new_local_time': new_local_time}
os.remove(file_path)
return True
def notify(url): def notify(url):
...@@ -220,6 +225,8 @@ def search_display(): ...@@ -220,6 +225,8 @@ def search_display():
if win: if win:
from twisted.internet import protocol from twisted.internet import protocol
from twisted.protocols import basic from twisted.protocols import basic
from winotify import Notification
from prompt import prompt_yes_no
clients = set() clients = set()
port = 25683 port = 25683
...@@ -260,13 +267,40 @@ if win: ...@@ -260,13 +267,40 @@ if win:
def lineReceived(self, line): def lineReceived(self, line):
logger.debug("received %s %s" % (line, type(line))) # logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str): if not isinstance(line, str):
line = line.decode() line = line.decode()
if line.startswith('cifs://'): if line.startswith('cifs://'):
mount_smb(line) mount_smb(line)
else: else:
open_in_browser(line) file_path = os.path.join(get_temp_dir(), file_name)
if file_already_exists(file_path):
os.remove(file_path)
if file_already_exists(file_path):
raise Exception(
"Couldn't create file %s as new" %
file_path)
with open(file_path, "w") as f:
json.dump(line, f)
# open_in_browser(line)
toast = Notification(app_id="CIRCLE Agent", title="VM expiration", msg="VM expiring soon", duration="long")
toast.add_actions(label="renrew", launch=line)
toast.show()
ans = prompt_yes_no(
title="Warning",
message="This VM expiring soon",
yes_label="Renew",
no_label="Cancel",
timeout_seconds=60, )
if ans == "yes":
ret = accept(line)
prompt_yes_no(
title="Info",
message=ret['msg'],
yes_label="OK",
no_label="",
timeout_seconds=10 if ret['ret'] else 60)
class SubFactory(protocol.ReconnectingClientFactory): class SubFactory(protocol.ReconnectingClientFactory):
......
# prompt.py
# Tkinter modal prompt: 1 or 2 buttons, optional timeout.
# Thread-free (Tkinter-safe). Good for PyInstaller.
from __future__ import annotations
from typing import Optional, Literal
Result = Literal["yes", "no", "timeout"]
def prompt_yes_no(
title: str,
message: str,
yes_label: str = "Yes",
no_label: str = "No", # if empty/whitespace => single-button mode
default: Literal["yes", "no"] = "yes",
timeout_seconds: Optional[int] = None,
show_countdown: bool = True,
topmost: bool = True,
wrap_width: int = 420,
) -> Result:
import tkinter as tk
from tkinter import ttk
single_button = (no_label.strip() == "")
# Default result if user closes window (X) or ESC
result_value: Result = "no"
root = tk.Tk()
root.title(title)
root.resizable(False, False)
if topmost:
root.attributes("-topmost", True)
frame = ttk.Frame(root, padding=16)
frame.grid(row=0, column=0)
lbl = ttk.Label(frame, text=message, justify="left", wraplength=wrap_width)
lbl.grid(row=0, column=0, columnspan=2, sticky="w")
countdown_var = tk.StringVar(value="")
countdown_lbl = ttk.Label(frame, textvariable=countdown_var)
countdown_lbl.grid(row=1, column=0, columnspan=2, sticky="w", pady=(8, 0))
after_id = None
cancelled = False
def cancel_timer():
nonlocal after_id
if after_id is not None:
try:
root.after_cancel(after_id)
except Exception:
pass
after_id = None
def finish(value: Result):
nonlocal cancelled, result_value
cancelled = True
result_value = value
cancel_timer()
try:
root.destroy()
except Exception:
pass
def on_yes():
finish("yes")
def on_no():
finish("no")
# Buttons row
btn_row = 2
btn_yes = ttk.Button(frame, text=yes_label, command=on_yes)
btn_yes.grid(row=btn_row, column=0, padx=6, pady=(16, 0))
btn_no = None
if not single_button:
btn_no = ttk.Button(frame, text=no_label, command=on_no)
btn_no.grid(row=btn_row, column=1, padx=6, pady=(16, 0))
else:
# In single-button mode, span across
btn_yes.grid_configure(column=0, columnspan=2)
# Default focus
if default == "yes" or single_button:
btn_yes.focus_set()
else:
if btn_no is not None:
btn_no.focus_set()
# Window close acts like "no" (safe default)
root.protocol("WM_DELETE_WINDOW", on_no)
# Keyboard shortcuts:
# ESC => No
root.bind("<Escape>", lambda _evt: on_no())
# Enter => default (or the only button)
if single_button or default == "yes":
root.bind("<Return>", lambda _evt: on_yes())
else:
root.bind("<Return>", lambda _evt: on_no())
# Center window
root.update_idletasks()
w, h = root.winfo_width(), root.winfo_height()
sw, sh = root.winfo_screenwidth(), root.winfo_screenheight()
root.geometry(f"+{(sw - w)//2}+{(sh - h)//3}")
# Timeout (thread-free)
if timeout_seconds and timeout_seconds > 0:
remaining = int(timeout_seconds)
def tick():
nonlocal remaining, after_id
if cancelled or not root.winfo_exists():
return
if remaining <= 0:
finish("timeout")
return
if show_countdown:
countdown_var.set(f"Auto close in {remaining} seconds…")
remaining -= 1
after_id = root.after(1000, tick)
after_id = root.after(1000, tick)
else:
# If no timeout, hide countdown line (optional)
if not show_countdown:
countdown_var.set("")
root.mainloop()
return result_value
if __name__ == "__main__":
# Demo
r = prompt_yes_no(
title="VM expiring",
message="A VM 5 percen belul suspend lesz.\n\nSzeretned meghosszabbitani?",
yes_label="Renew",
no_label="Cancel",
timeout_seconds=20,
)
print("Result:", r)
r2 = prompt_yes_no(
title="Info",
message="This is a single-button message.",
yes_label="OK",
no_label="",
timeout_seconds=10,
)
print("Result2:", r2)
...@@ -10,4 +10,5 @@ tzlocal ...@@ -10,4 +10,5 @@ tzlocal
pytz pytz
pywin32 pywin32
wmi wmi
winotify
from twisted.protocols.basic import LineReceiver from twisted.protocols.basic import LineReceiver
import sys
import json import json
import logging import logging
from logging.handlers import TimedRotatingFileHandler
import platform import platform
from os import chmod from os import chmod
from shutil import copyfile from shutil import copyfile
...@@ -11,9 +13,60 @@ try: ...@@ -11,9 +13,60 @@ try:
except NameError: except NameError:
unicode = str unicode = str
logger = logging.getLogger() logger = logging.getLogger(__name__)
system = platform.system() system = platform.system()
def setup_logging(logfile=None, backup_count=3):
logger = logging.getLogger()
logger.handlers.clear()
formatter = logging.Formatter(
"[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] "
"%(module)s.%(funcName)s:%(lineno)d] %(message)s",
"%d/%b/%Y %H:%M:%S",
)
if logfile != None:
handler = TimedRotatingFileHandler(
filename=logfile,
when="midnight",
backupCount=backup_count,
encoding="utf-8",
delay=True,
)
else:
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def setup_logging_timed(logfile: str, level: str, backup_count: int):
handler = TimedRotatingFileHandler(
filename=str(log_path),
when="midnight",
interval=1,
backupCount=backup_count,
encoding="utf-8",
utc=False, # use local time (Budapest)
delay=True, # create file only when first log is emitted
)
formatter = logging.Formatter(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] "
"%(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
handler.setFormatter(formatter)
root = logging.getLogger()
root.setLevel(level)
# Avoid duplicate handlers if setup called multiple times
root.handlers.clear()
root.addHandler(handler)
return root
class SerialLineReceiverBase(LineReceiver, object): class SerialLineReceiverBase(LineReceiver, object):
MAX_LENGTH = 1024*1024*128 MAX_LENGTH = 1024*1024*128
...@@ -26,7 +79,7 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -26,7 +79,7 @@ class SerialLineReceiverBase(LineReceiver, object):
super(SerialLineReceiverBase, self).__init__(*args, **kwargs) super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args): def send_response(self, response, args):
logger.debug("send_response %s %s" % (response, args)) # logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response, self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n') 'args': args}) + '\r\n')
...@@ -41,7 +94,7 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -41,7 +94,7 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method") raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data): def lineReceived(self, data):
logger.debug("lineReceived: %s", data) # logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)): if (isinstance(data, unicode)):
data = data.strip('\0') data = data.strip('\0')
else: else:
...@@ -53,14 +106,14 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -53,14 +106,14 @@ class SerialLineReceiverBase(LineReceiver, object):
args = {} args = {}
command = data.get('command', None) command = data.get('command', None)
response = data.get('response', None) response = data.get('response', None)
logger.debug('[serial] valid json: %s' % (data, )) # logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e: except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e)) # logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer() self.clearLineBuffer()
return return
if command is not None and isinstance(command, unicode): if command is not None and isinstance(command, unicode):
logger.debug('received command: %s (%s)' % (command, args)) # logger.debug('received command: %s (%s)' % (command, args[:10]))
try: try:
self.handle_command(command, args) self.handle_command(command, args)
except Exception as e: except Exception as e:
...@@ -93,3 +146,4 @@ def copy_file(src, dst, overw=False, mode=None): ...@@ -93,3 +146,4 @@ def copy_file(src, dst, overw=False, mode=None):
return copyed return copyed
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F watchdog-winservice.py pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-wdog-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw
\ No newline at end of file \ No newline at end of file
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
working_directory = r"C:\circle" # noqa
from os.path import join from os.path import join
import sys
import logging import logging
import tarfile import tarfile
from io import BytesIO from io import BytesIO
...@@ -19,6 +18,10 @@ from twisted.internet import reactor ...@@ -19,6 +18,10 @@ from twisted.internet import reactor
from .network import change_ip_windows from .network import change_ip_windows
from context import BaseContext from context import BaseContext
from windows.winutils import (
is_frozen_exe, copy_running_exe,
update_service_binpath, servicePostUpdate
)
try: try:
# Python 2: "unicode" is built-in # Python 2: "unicode" is built-in
...@@ -26,14 +29,22 @@ try: ...@@ -26,14 +29,22 @@ try:
except NameError: except NameError:
unicode = str unicode = str
logger = logging.getLogger() logger = logging.getLogger(__name__)
class Context(BaseContext): class Context(BaseContext):
service_name = "CIRCLE-agent"
working_dir = r"C:\circle"
exe = "circle-agent.exe"
@staticmethod @staticmethod
def postUpdate():
exe_path = join(Context.working_dir, Context.exe)
return servicePostUpdate(Context.service_name, exe_path)
@staticmethod
def change_password(password): def change_password(password):
BaseContext.placed = True Context.placed = True
from win32com import adsi from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud') ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo() ads_obj.Getinfo()
...@@ -45,7 +56,7 @@ class Context(BaseContext): ...@@ -45,7 +56,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
BaseContext.placed = True Context.placed = True
nameservers = dns.replace(' ', '').split(',') nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers) change_ip_windows(interfaces, nameservers)
...@@ -99,19 +110,6 @@ class Context(BaseContext): ...@@ -99,19 +110,6 @@ class Context(BaseContext):
myfile.write(data) myfile.write(data)
@staticmethod @staticmethod
def _update_registry(dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, join(dir, executable))
return old_executable
@staticmethod
def update(filename, executable, checksum, uuid): def update(filename, executable, checksum, uuid):
with open(filename, "r") as f: with open(filename, "r") as f:
data = f.read() data = f.read()
...@@ -122,13 +120,14 @@ class Context(BaseContext): ...@@ -122,13 +120,14 @@ class Context(BaseContext):
decoded = BytesIO(b64decode(data)) decoded = BytesIO(b64decode(data))
try: try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz') tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory) tar.extractall(Context.working_dir)
except tarfile.ReadError as e: except tarfile.ReadError as e:
logger.error(e) logger.error(e)
logger.info("Transfer completed!") logger.info("Transfer completed!")
old_exe = Context._update_registry(working_directory, executable) old_exe = update_service_binpath("CIRCLE-agent", join(Context.working_dir, executable))
logger.info('%s Updated', old_exe) logger.info('%s Updated', old_exe)
reactor.stop() Context.exit_code = 1
reactor.callLater(0, reactor.stop)
@staticmethod @staticmethod
def ipaddresses(): def ipaddresses():
...@@ -149,7 +148,7 @@ class Context(BaseContext): ...@@ -149,7 +148,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def get_agent_version(): def get_agent_version():
try: try:
with open(join(working_directory, 'version.txt')) as f: with open(join(Context.working_dir, 'version.txt')) as f:
return f.readline() return f.readline()
except IOError: except IOError:
return None return None
...@@ -20,7 +20,7 @@ from twisted.internet import abstract ...@@ -20,7 +20,7 @@ from twisted.internet import abstract
# sibling imports # sibling imports
import logging import logging
logger = logging.getLogger() logger = logging.getLogger(__name__)
class SerialPort(abstract.FileDescriptor): class SerialPort(abstract.FileDescriptor):
...@@ -68,7 +68,7 @@ class SerialPort(abstract.FileDescriptor): ...@@ -68,7 +68,7 @@ class SerialPort(abstract.FileDescriptor):
self._overlappedRead) self._overlappedRead)
def serialReadEvent(self): def serialReadEvent(self):
logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh)) # logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try: try:
n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1) n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
except Exception as e: except Exception as e:
...@@ -89,14 +89,14 @@ class SerialPort(abstract.FileDescriptor): ...@@ -89,14 +89,14 @@ class SerialPort(abstract.FileDescriptor):
data = str.encode(data) data = str.encode(data)
if self.writeInProgress: if self.writeInProgress:
self.outQueue.append(data) self.outQueue.append(data)
logger.debug("added to queue") # logger.debug("added to queue")
else: else:
self.writeInProgress = 1 self.writeInProgress = 1
ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite) ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file %s", ret) # logger.debug("Writed to file %s", ret)
def serialWriteEvent(self): def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh)) # logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost()) logger.debug(self.connLost())
self.writeInProgress = 0 self.writeInProgress = 0
......
import os
import sys
import logging
from shutil import copy
from os.path import join, normcase, normpath
from winreg import (
OpenKeyEx, QueryValueEx, SetValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS, KEY_READ
)
logger = logging.getLogger()
def is_frozen_exe() -> bool:
return bool(getattr(sys, "frozen", False))
def update_service_binpath(service_name: str, exe_path: str) -> str:
"""
Update service ImagePath in registry to point to exe_path.
Returns the previous ImagePath string.
"""
with OpenKeyEx(HKEY_LOCAL_MACHINE,
fr"SYSTEM\CurrentControlSet\services\{service_name}",
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, f'"{exe_path}"')
return old_executable
def copy_running_exe(dest: str) -> bool:
"""
Startup helper:
- If the runnin executable is not
then copy it to dest (overwriting old dest if present),
- Otherwise do nothing.
Returns True if it performed changes, otherwise False.
"""
# Where are we actually running from?
current_exe = sys.executable
# Windows paths are case-insensitive -> compare with normcase
if normcase(current_exe) == normcase(dest):
return False
copy(current_exe, dest)
return True
def servicePostUpdate(service_name, exe_path):
logger.debug("Running exe %s", sys.executable)
if is_frozen_exe() and copy_running_exe(exe_path):
logger.debug("The running agent copyed to %s", exe_path)
old_exe = update_service_binpath(service_name, exe_path)
logger.debug("%s service binpath updated %s -> %s", service_name, old_exe, exe_path)
return True
return False
def getRegistryVal(reg_path: str, name: str, default=None):
"""
Read HKLM\\<reg_path>\\<name> and return its value.
If key or value does not exist, return default.
Example:
getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\circle-agent",
"LogLevel",
"INFO"
)
"""
value=default
try:
with OpenKeyEx(HKEY_LOCAL_MACHINE, reg_path, 0, KEY_READ) as key:
value, _ = QueryValueEx(key, name)
except Exception as e:
logging.debug("Registry read failed %s\\%s: %s",
reg_path, name, e
)
return value
def get_windows_version():
if sys.platform != "win32":
return None
ver = sys.getwindowsversion()
major = ver.major
minor = ver.minor
build = ver.build
# Windows 7
if major == 6 and minor == 1:
return "Windows_7"
# Windows 8 / 8.1
if major == 6 and minor in (2, 3):
return "Windows_8"
# Windows 10 / 11
if major == 10:
# Windows 11 starts at build 22000
if build >= 22000:
return "Windows_11"
else:
return "Windows_10"
return f"Windows_{major}_{minor}_{build})"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment