Commit 67869460 by Szeberényi Imre

service fix, update fix

parent 1f863982
......@@ -2,67 +2,75 @@ import logging
from logging.handlers import NTEventLogHandler
from time import sleep
import os
from os.path import join
import servicemanager
import socket
import sys
import winerror
import win32event
import win32service
import win32serviceutil
logger = logging.getLogger()
fh = NTEventLogHandler(
"CIRCLE Watchdog", dllname=os.path.dirname(__file__))
from utils import setup_logging
from windows.winutils import getRegistryVal, get_windows_version, servicePostUpdate
if getattr(sys, "frozen", False):
logger = setup_logging(logfile=r"C:\Circle\watchdog.log")
else:
logger = setup_logging()
fh = NTEventLogHandler("CIRCLE Watchdog")
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO')
level = getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-Agent\\Parameters",
"LogLevel",
"INFO"
)
logger.setLevel(level)
logger.info("%s loaded", __file__)
service_name = "circle-agent"
stopped = False
def watch():
def check_service(service_name):
return win32serviceutil.QueryServiceStatus(service_name)[1] == 4
def start_service():
win32serviceutil.StartService(service_name)
timo_base = 20
timo = timo_base
sleep(6*timo) # boot process may have triggered the agent, so we are patient
while True:
if not check_service(service_name):
logger.info("Service %s is not running.", service_name)
try:
start_service()
timo = timo_base
logger.info("Service %s started.", service_name)
except Exception as e:
timo *= 2
logger.exception("Cant start service %s new timo: %s" % (service_name, timo))
if stopped:
return
sleep(timo)
class AppServerSvc (win32serviceutil.ServiceFramework):
_svc_name_ = "circle-watchdog"
_svc_display_name_ = "CIRCLE Watchdog"
_svc_description_ = "Watchdog for CIRCLE Agent"
def __init__(self, args):
win32serviceutil.ServiceFramework.__init__(self, args)
self.hWaitStop = win32event.CreateEvent(None, 0, 0, None)
socket.setdefaulttimeout(60)
self._stopped = False
def watch(self, checked_service):
logger.debug("watch...")
def check_service(checked_service):
return win32serviceutil.QueryServiceStatus(checked_service)[1] == 4
def start_service():
win32serviceutil.StartService(checked_service)
timo_base = 20
timo = timo_base
sleep(6*timo) # boot process may have triggered the agent, so we are patient
while not self._stopped:
logger.debug("checking....(timo: %d", timo)
if not check_service(checked_service):
logger.info("Service %s is not running.", checked_service)
try:
start_service()
timo = timo_base
logger.info("Service %s restarted.", checked_service)
except Exception:
timo = min(timo * 2, 15 * 60) # max 15 perc
logger.exception("Cant start service %s new timo: %s" % (checked_service, timo))
sleep(timo)
def SvcStop(self):
self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING)
win32event.SetEvent(self.hWaitStop)
global stopped
stopped = True
self._stopped = True
logger.info("%s stopped", __file__)
def SvcDoRun(self):
......@@ -70,11 +78,27 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, ''))
logger.info("%s starting", __file__)
watch()
working_dir = r"C:\circle"
exe = "circle-watchdog.exe"
exe_path = join(working_dir, exe)
logger.debug("hahooo %s %s", self._svc_name_, exe_path)
if servicePostUpdate(self._svc_name_, exe_path):
# Service updated, Restart needed
logger.debug("update....")
self.ReportServiceStatus(
win32service.SERVICE_STOPPED,
win32ExitCode=winerror.ERROR_SERVICE_SPECIFIC_ERROR, # 1066
svcExitCode=int(1)
)
return
self.watch("circle-agent")
# normal stop
self.ReportServiceStatus(win32service.SERVICE_STOPPED)
def main():
logger.info("Started: %s", sys.argv)
if len(sys.argv) == 1:
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
......
import logging
from logging.handlers import NTEventLogHandler
import os
import servicemanager
import socket
import sys
import winerror
import win32event
import win32service
import win32serviceutil
import logging
from logging.handlers import NTEventLogHandler
#import agent
from agent import main as agent_main, reactor
#import agent, reactor
from agent import main as agent_main
from twisted.internet import reactor
logger = logging.getLogger()
fh = NTEventLogHandler(
"CIRCLE Agent", dllname=os.path.dirname(__file__))
fh = NTEventLogHandler("CIRCLE Agent")
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
#level = os.environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel('INFO')
#logger.propagate = False
#logger.setLevel('INFO')
logger.info("%s loaded", __file__)
class AppServerSvc (win32serviceutil.ServiceFramework):
_svc_name_ = "circle-agent"
_svc_display_name_ = "CIRCLE Agent"
_svc_desciption_ = "CIRCLE cloud contextualization agent"
_svc_description_ = "CIRCLE cloud contextualization agent"
def __init__(self, args):
win32serviceutil.ServiceFramework.__init__(self, args)
......@@ -37,7 +38,12 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def SvcStop(self):
self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING)
win32event.SetEvent(self.hWaitStop)
reactor.stop()
try:
reactor.stop()
reactor.callFromThread(reactor.stop)
# reactor.callLater(0, reactor.stop)
except Exception:
logger.exception("reactor.stop failed")
logger.info("%s stopped", __file__)
def SvcDoRun(self):
......@@ -49,12 +55,25 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager.PYS_SERVICE_STARTED,
(self._svc_name_, ''))
break
except Exception as e:
logger.exception("Servicemanager busy?", e)
except Exception:
logger.exception("Servicemanager busy?")
cnt -= 1
if cnt:
logger.info("Starting agent_main")
agent_main()
ret = agent_main()
logger.error("agent_main returned ret=%r type=%s", ret, type(ret).__name__)
logger.info("agent_main finished")
if ret != 0:
# “Service-specific error”
self.ReportServiceStatus(
win32service.SERVICE_STOPPED,
win32ExitCode=winerror.ERROR_SERVICE_SPECIFIC_ERROR, # 1066
svcExitCode=int(ret)
)
return
# normal stop
self.ReportServiceStatus(win32service.SERVICE_STOPPED)
def main():
......
......@@ -8,16 +8,21 @@ import subprocess
import sys
from shutil import rmtree
logging.basicConfig(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
logger = logging.getLogger()
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path
from utils import setup_logging
if getattr(sys, "frozen", False):
logger = setup_logging(logfile=r"C:\Circle\agent.log")
else:
logger = setup_logging()
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
system = platform.system() # noqa
logger.debug("system:%s", system)
win = platform.system() == "Windows"
if len(sys.argv) != 1 and (system == "Linux" or system == "FreeBSD"): # noqa
logger.info("Installing agent on %s system", system)
......@@ -39,9 +44,17 @@ import uptime
from inspect import getargs, isfunction
from utils import SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
# (relative import error.
from context import BaseContext, get_context, get_serial # noqa
if win:
from windows.winutils import getRegistryVal, get_windows_version
level = getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-agent\\Parameters",
"LogLevel",
level
)
logger.setLevel(level)
# system = get_windows_version()
from context import get_context, get_serial # noqa
try:
# Python 2: "unicode" is built-in
......@@ -86,14 +99,14 @@ class SerialLineReceiver(SerialLineReceiverBase):
self.send_command(command='agent_stopped', args={})
def mayStartNow(self):
if BaseContext.placed:
if Context.placed:
self.mayStartNowId.stop()
logger.info("Placed")
return
self.send_startMsg()
def tick(self):
logger.debug("Sending tick")
# logger.debug("Sending tick")
try:
self.send_status()
except Exception as e:
......@@ -109,7 +122,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
return d
def send_startMsg(self):
logger.debug("Sending start message...")
logger.debug("Sending start message: %s %s", Context.get_agent_version(), system)
# Hack for flushing the lower level buffersr
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
......@@ -132,10 +145,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
"disk": disk_usage,
"user": {"count": len(psutil.users())}}
self.send_response(response='status', args=args)
logger.debug("send_status finished")
# logger.debug("send_status finished")
def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args))
# logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." %
......@@ -161,10 +174,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished")
# logger.debug("_check_args finished")
def _get_command(self, command, args):
logger.debug("_get_command %s %s" % (command, args))
# logger.debug("_get_command %s" % command)
if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command)
try:
......@@ -177,7 +190,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
self._pretty_fun(func))
self._check_args(func, args)
logger.debug("_get_command finished")
# logger.debug("_get_command finished")
return func
@staticmethod
......@@ -194,9 +207,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
return "<%s>" % type(fun).__name__
def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args))
logger.debug("handle_command %s" % command)
func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args))
# logger.debug("Call cmd: %s" % func)
retval = func(**args)
logger.debug("Retval: %s" % retval)
self.send_response(
......@@ -207,6 +220,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
pass
def main():
if Context.postUpdate():
# Service updated, Restart needed
return 1
# Get proper serial class and port name
(serial, port) = get_serial()
logger.info("Opening port %s", port)
......@@ -220,7 +238,7 @@ def main():
logger.debug("Starting reactor.")
reactor.run()
logger.debug("Reactor finished.")
return Context.exit_code
if __name__ == '__main__':
main()
......@@ -3,15 +3,20 @@
# Should be in autostart and run by the user logged in
import logging
from os.path import join
from os import environ
from notify import run_client, get_temp_dir
logger = logging.getLogger()
fh = logging.FileHandler("agent-client.log")
logfile = join(get_temp_dir(), "agent-client.log")
fh = logging.FileHandler(logfile)
formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
level = environ.get('LOGLEVEL', 'INFO')
logger.setLevel(level)
from notify import run_client
if __name__ == '__main__':
run_client()
......@@ -5,7 +5,21 @@ import logging
import platform
import re
logger = logging.getLogger()
logger = logging.getLogger(__name__)
# --- compatibility patch for old libs expecting inspect.getargspec (Py<3.11) ---
import inspect
from collections import namedtuple
if not hasattr(inspect, "getargspec"):
ArgSpec = namedtuple("ArgSpec", "args varargs keywords defaults")
def getargspec(func):
fs = inspect.getfullargspec(func)
return ArgSpec(fs.args, fs.varargs, fs.varkw, fs.defaults)
inspect.getargspec = getargspec
# ---------------------------------------------------------------------------
def _get_virtio_device():
path = None
......@@ -78,7 +92,12 @@ def get_serial():
class BaseContext(object):
placed = False # if we reciwed password or net commands
exit_code = 0
@staticmethod
def postUpdate():
return false
@staticmethod
def change_password(password):
pass
......
......@@ -34,7 +34,7 @@ try:
except NameError:
unicode = str
logger = logging.getLogger()
logger = logging.getLogger(__name__)
logger.debug("notify imported")
file_name = "vm_renewal.json"
win = platform.system() == "Windows"
......@@ -51,6 +51,8 @@ def parse_arguments():
def get_temp_dir():
if os.getenv("TMPDIR"):
temp_dir = os.getenv("TMPDIR")
elif os.getenv("TEMP"):
temp_dir = os.getenv("TEMP")
elif os.getenv("TMP"):
temp_dir = os.getenv("TMP")
elif os.path.exists("/tmp"):
......@@ -72,20 +74,25 @@ def wall(text):
process.communicate(input=text)[0]
def accept():
def accept(url=None):
import datetime
from tzlocal import get_localzone
from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path):
print("There is no recent notification to accept.")
return False
# Load the saved url
url = json.load(open(file_path, "r"))
if url == None:
file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path):
print("There is no recent notification to accept.")
return False
# Load the saved url
url = json.load(open(file_path, "r"))
os.remove(file_path)
cj = cookielib.CookieJar()
opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj))
msh = None
ret = False
new_local_time = None
try:
opener.open(url) # GET to collect cookies
cookies = cj._cookies_for_request(urllib2.Request(url))
......@@ -95,25 +102,23 @@ def accept():
b"x-csrftoken": token})
rsp = opener.open(req)
data = json.load(rsp)
logger.debug("data %r", data)
newtime = data["new_suspend_time"]
# Parse time from JSON (Create UTC Localized Datetime objec)
parsed_time = datetime.datetime.strptime(
newtime[:-6], "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=UTC)
# Convert to the machine localization
new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
msg = data["message"]
# # Parse time from JSON (Create UTC Localized Datetime objec)
# parsed_time = datetime.datetime.strptime(
# newtime[:-6], "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=UTC)
# # Convert to the machine localization
# new_local_time = parsed_time.astimezone(
# get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e:
print("Parsing time failed: %s" % e)
msg = "Parsing time failed: %s" % e
except Exception as e:
print(e)
print("Renewal failed. Please try it manually at %s" % url)
msg = "Renewal failed. Please try it manually at %s" % url
logger.exception("renew failed")
return False
else:
print("Renew succeeded. The machine will be "
"suspended at %s." % new_local_time)
os.remove(file_path)
return True
ret = True
return { 'ret': ret, 'msg': msg, 'new_local_time': new_local_time}
def notify(url):
......@@ -220,6 +225,8 @@ def search_display():
if win:
from twisted.internet import protocol
from twisted.protocols import basic
from winotify import Notification
from prompt import prompt_yes_no
clients = set()
port = 25683
......@@ -260,13 +267,40 @@ if win:
def lineReceived(self, line):
logger.debug("received %s %s" % (line, type(line)))
# logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str):
line = line.decode()
if line.startswith('cifs://'):
mount_smb(line)
else:
open_in_browser(line)
file_path = os.path.join(get_temp_dir(), file_name)
if file_already_exists(file_path):
os.remove(file_path)
if file_already_exists(file_path):
raise Exception(
"Couldn't create file %s as new" %
file_path)
with open(file_path, "w") as f:
json.dump(line, f)
# open_in_browser(line)
toast = Notification(app_id="CIRCLE Agent", title="VM expiration", msg="VM expiring soon", duration="long")
toast.add_actions(label="renrew", launch=line)
toast.show()
ans = prompt_yes_no(
title="Warning",
message="This VM expiring soon",
yes_label="Renew",
no_label="Cancel",
timeout_seconds=60, )
if ans == "yes":
ret = accept(line)
prompt_yes_no(
title="Info",
message=ret['msg'],
yes_label="OK",
no_label="",
timeout_seconds=10 if ret['ret'] else 60)
class SubFactory(protocol.ReconnectingClientFactory):
......
# prompt.py
# Tkinter modal prompt: 1 or 2 buttons, optional timeout.
# Thread-free (Tkinter-safe). Good for PyInstaller.
from __future__ import annotations
from typing import Optional, Literal
Result = Literal["yes", "no", "timeout"]
def prompt_yes_no(
title: str,
message: str,
yes_label: str = "Yes",
no_label: str = "No", # if empty/whitespace => single-button mode
default: Literal["yes", "no"] = "yes",
timeout_seconds: Optional[int] = None,
show_countdown: bool = True,
topmost: bool = True,
wrap_width: int = 420,
) -> Result:
import tkinter as tk
from tkinter import ttk
single_button = (no_label.strip() == "")
# Default result if user closes window (X) or ESC
result_value: Result = "no"
root = tk.Tk()
root.title(title)
root.resizable(False, False)
if topmost:
root.attributes("-topmost", True)
frame = ttk.Frame(root, padding=16)
frame.grid(row=0, column=0)
lbl = ttk.Label(frame, text=message, justify="left", wraplength=wrap_width)
lbl.grid(row=0, column=0, columnspan=2, sticky="w")
countdown_var = tk.StringVar(value="")
countdown_lbl = ttk.Label(frame, textvariable=countdown_var)
countdown_lbl.grid(row=1, column=0, columnspan=2, sticky="w", pady=(8, 0))
after_id = None
cancelled = False
def cancel_timer():
nonlocal after_id
if after_id is not None:
try:
root.after_cancel(after_id)
except Exception:
pass
after_id = None
def finish(value: Result):
nonlocal cancelled, result_value
cancelled = True
result_value = value
cancel_timer()
try:
root.destroy()
except Exception:
pass
def on_yes():
finish("yes")
def on_no():
finish("no")
# Buttons row
btn_row = 2
btn_yes = ttk.Button(frame, text=yes_label, command=on_yes)
btn_yes.grid(row=btn_row, column=0, padx=6, pady=(16, 0))
btn_no = None
if not single_button:
btn_no = ttk.Button(frame, text=no_label, command=on_no)
btn_no.grid(row=btn_row, column=1, padx=6, pady=(16, 0))
else:
# In single-button mode, span across
btn_yes.grid_configure(column=0, columnspan=2)
# Default focus
if default == "yes" or single_button:
btn_yes.focus_set()
else:
if btn_no is not None:
btn_no.focus_set()
# Window close acts like "no" (safe default)
root.protocol("WM_DELETE_WINDOW", on_no)
# Keyboard shortcuts:
# ESC => No
root.bind("<Escape>", lambda _evt: on_no())
# Enter => default (or the only button)
if single_button or default == "yes":
root.bind("<Return>", lambda _evt: on_yes())
else:
root.bind("<Return>", lambda _evt: on_no())
# Center window
root.update_idletasks()
w, h = root.winfo_width(), root.winfo_height()
sw, sh = root.winfo_screenwidth(), root.winfo_screenheight()
root.geometry(f"+{(sw - w)//2}+{(sh - h)//3}")
# Timeout (thread-free)
if timeout_seconds and timeout_seconds > 0:
remaining = int(timeout_seconds)
def tick():
nonlocal remaining, after_id
if cancelled or not root.winfo_exists():
return
if remaining <= 0:
finish("timeout")
return
if show_countdown:
countdown_var.set(f"Auto close in {remaining} seconds…")
remaining -= 1
after_id = root.after(1000, tick)
after_id = root.after(1000, tick)
else:
# If no timeout, hide countdown line (optional)
if not show_countdown:
countdown_var.set("")
root.mainloop()
return result_value
if __name__ == "__main__":
# Demo
r = prompt_yes_no(
title="VM expiring",
message="A VM 5 percen belul suspend lesz.\n\nSzeretned meghosszabbitani?",
yes_label="Renew",
no_label="Cancel",
timeout_seconds=20,
)
print("Result:", r)
r2 = prompt_yes_no(
title="Info",
message="This is a single-button message.",
yes_label="OK",
no_label="",
timeout_seconds=10,
)
print("Result2:", r2)
......@@ -10,4 +10,5 @@ tzlocal
pytz
pywin32
wmi
winotify
from twisted.protocols.basic import LineReceiver
import sys
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import platform
from os import chmod
from shutil import copyfile
......@@ -11,9 +13,60 @@ try:
except NameError:
unicode = str
logger = logging.getLogger()
logger = logging.getLogger(__name__)
system = platform.system()
def setup_logging(logfile=None, backup_count=3):
logger = logging.getLogger()
logger.handlers.clear()
formatter = logging.Formatter(
"[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] "
"%(module)s.%(funcName)s:%(lineno)d] %(message)s",
"%d/%b/%Y %H:%M:%S",
)
if logfile != None:
handler = TimedRotatingFileHandler(
filename=logfile,
when="midnight",
backupCount=backup_count,
encoding="utf-8",
delay=True,
)
else:
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def setup_logging_timed(logfile: str, level: str, backup_count: int):
handler = TimedRotatingFileHandler(
filename=str(log_path),
when="midnight",
interval=1,
backupCount=backup_count,
encoding="utf-8",
utc=False, # use local time (Budapest)
delay=True, # create file only when first log is emitted
)
formatter = logging.Formatter(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] "
"%(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
handler.setFormatter(formatter)
root = logging.getLogger()
root.setLevel(level)
# Avoid duplicate handlers if setup called multiple times
root.handlers.clear()
root.addHandler(handler)
return root
class SerialLineReceiverBase(LineReceiver, object):
MAX_LENGTH = 1024*1024*128
......@@ -26,7 +79,7 @@ class SerialLineReceiverBase(LineReceiver, object):
super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args):
logger.debug("send_response %s %s" % (response, args))
# logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n')
......@@ -41,7 +94,7 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data):
logger.debug("lineReceived: %s", data)
# logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)):
data = data.strip('\0')
else:
......@@ -53,14 +106,14 @@ class SerialLineReceiverBase(LineReceiver, object):
args = {}
command = data.get('command', None)
response = data.get('response', None)
logger.debug('[serial] valid json: %s' % (data, ))
# logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e))
# logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer()
return
if command is not None and isinstance(command, unicode):
logger.debug('received command: %s (%s)' % (command, args))
# logger.debug('received command: %s (%s)' % (command, args[:10]))
try:
self.handle_command(command, args)
except Exception as e:
......@@ -93,3 +146,4 @@ def copy_file(src, dst, overw=False, mode=None):
return copyed
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F watchdog-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw
\ No newline at end of file
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-wdog-winservice.py
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
working_directory = r"C:\circle" # noqa
from os.path import join
import sys
import logging
import tarfile
from io import BytesIO
......@@ -19,6 +18,10 @@ from twisted.internet import reactor
from .network import change_ip_windows
from context import BaseContext
from windows.winutils import (
is_frozen_exe, copy_running_exe,
update_service_binpath, servicePostUpdate
)
try:
# Python 2: "unicode" is built-in
......@@ -26,14 +29,22 @@ try:
except NameError:
unicode = str
logger = logging.getLogger()
logger = logging.getLogger(__name__)
class Context(BaseContext):
service_name = "CIRCLE-agent"
working_dir = r"C:\circle"
exe = "circle-agent.exe"
@staticmethod
def postUpdate():
exe_path = join(Context.working_dir, Context.exe)
return servicePostUpdate(Context.service_name, exe_path)
@staticmethod
def change_password(password):
BaseContext.placed = True
Context.placed = True
from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo()
......@@ -45,7 +56,7 @@ class Context(BaseContext):
@staticmethod
def change_ip(interfaces, dns):
BaseContext.placed = True
Context.placed = True
nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers)
......@@ -99,19 +110,6 @@ class Context(BaseContext):
myfile.write(data)
@staticmethod
def _update_registry(dir, executable):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from winreg import (OpenKeyEx, SetValueEx, QueryValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS)
with OpenKeyEx(HKEY_LOCAL_MACHINE,
r'SYSTEM\CurrentControlSet\services\circle-agent',
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, join(dir, executable))
return old_executable
@staticmethod
def update(filename, executable, checksum, uuid):
with open(filename, "r") as f:
data = f.read()
......@@ -122,13 +120,14 @@ class Context(BaseContext):
decoded = BytesIO(b64decode(data))
try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory)
tar.extractall(Context.working_dir)
except tarfile.ReadError as e:
logger.error(e)
logger.info("Transfer completed!")
old_exe = Context._update_registry(working_directory, executable)
old_exe = update_service_binpath("CIRCLE-agent", join(Context.working_dir, executable))
logger.info('%s Updated', old_exe)
reactor.stop()
Context.exit_code = 1
reactor.callLater(0, reactor.stop)
@staticmethod
def ipaddresses():
......@@ -149,7 +148,7 @@ class Context(BaseContext):
@staticmethod
def get_agent_version():
try:
with open(join(working_directory, 'version.txt')) as f:
with open(join(Context.working_dir, 'version.txt')) as f:
return f.readline()
except IOError:
return None
......@@ -20,7 +20,7 @@ from twisted.internet import abstract
# sibling imports
import logging
logger = logging.getLogger()
logger = logging.getLogger(__name__)
class SerialPort(abstract.FileDescriptor):
......@@ -68,7 +68,7 @@ class SerialPort(abstract.FileDescriptor):
self._overlappedRead)
def serialReadEvent(self):
logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
# logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try:
n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
except Exception as e:
......@@ -89,14 +89,14 @@ class SerialPort(abstract.FileDescriptor):
data = str.encode(data)
if self.writeInProgress:
self.outQueue.append(data)
logger.debug("added to queue")
# logger.debug("added to queue")
else:
self.writeInProgress = 1
ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file %s", ret)
# logger.debug("Writed to file %s", ret)
def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
# logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost())
self.writeInProgress = 0
......
import os
import sys
import logging
from shutil import copy
from os.path import join, normcase, normpath
from winreg import (
OpenKeyEx, QueryValueEx, SetValueEx,
HKEY_LOCAL_MACHINE, KEY_ALL_ACCESS, KEY_READ
)
logger = logging.getLogger()
def is_frozen_exe() -> bool:
return bool(getattr(sys, "frozen", False))
def update_service_binpath(service_name: str, exe_path: str) -> str:
"""
Update service ImagePath in registry to point to exe_path.
Returns the previous ImagePath string.
"""
with OpenKeyEx(HKEY_LOCAL_MACHINE,
fr"SYSTEM\CurrentControlSet\services\{service_name}",
0,
KEY_ALL_ACCESS) as key:
(old_executable, reg_type) = QueryValueEx(key, "ImagePath")
SetValueEx(key, "ImagePath", None, 2, f'"{exe_path}"')
return old_executable
def copy_running_exe(dest: str) -> bool:
"""
Startup helper:
- If the runnin executable is not
then copy it to dest (overwriting old dest if present),
- Otherwise do nothing.
Returns True if it performed changes, otherwise False.
"""
# Where are we actually running from?
current_exe = sys.executable
# Windows paths are case-insensitive -> compare with normcase
if normcase(current_exe) == normcase(dest):
return False
copy(current_exe, dest)
return True
def servicePostUpdate(service_name, exe_path):
logger.debug("Running exe %s", sys.executable)
if is_frozen_exe() and copy_running_exe(exe_path):
logger.debug("The running agent copyed to %s", exe_path)
old_exe = update_service_binpath(service_name, exe_path)
logger.debug("%s service binpath updated %s -> %s", service_name, old_exe, exe_path)
return True
return False
def getRegistryVal(reg_path: str, name: str, default=None):
"""
Read HKLM\\<reg_path>\\<name> and return its value.
If key or value does not exist, return default.
Example:
getRegistryVal(
r"SYSTEM\\CurrentControlSet\\Services\\circle-agent",
"LogLevel",
"INFO"
)
"""
value=default
try:
with OpenKeyEx(HKEY_LOCAL_MACHINE, reg_path, 0, KEY_READ) as key:
value, _ = QueryValueEx(key, name)
except Exception as e:
logging.debug("Registry read failed %s\\%s: %s",
reg_path, name, e
)
return value
def get_windows_version():
if sys.platform != "win32":
return None
ver = sys.getwindowsversion()
major = ver.major
minor = ver.minor
build = ver.build
# Windows 7
if major == 6 and minor == 1:
return "Windows_7"
# Windows 8 / 8.1
if major == 6 and minor in (2, 3):
return "Windows_8"
# Windows 10 / 11
if major == 10:
# Windows 11 starts at build 22000
if build >= 22000:
return "Windows_11"
else:
return "Windows_10"
return f"Windows_{major}_{minor}_{build})"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment