Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
CIRCLE3
/
agent
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Members
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit
67869460
authored
Jan 14, 2026
by
Szeberényi Imre
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
service fix, update fix
parent
1f863982
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
572 additions
and
124 deletions
+572
-124
__init__.py
+0
-0
agent-wdog-winservice.py
+56
-32
agent-winservice.py
+32
-13
agent.py
+37
-19
circle-notify.pyw
+8
-3
context.py
+20
-1
notify.py
+53
-19
prompt.py
+167
-0
requirements/windows.txt
+1
-0
utils.py
+60
-6
win_build.bat
+4
-4
windows/_win32context.py
+21
-22
windows/win32virtio.py
+5
-5
windows/winutils.py
+108
-0
No files found.
__init__.py
deleted
100644 → 0
View file @
1f863982
agent-wdog-winservice.py
View file @
67869460
...
@@ -2,67 +2,75 @@ import logging
...
@@ -2,67 +2,75 @@ import logging
from
logging.handlers
import
NTEventLogHandler
from
logging.handlers
import
NTEventLogHandler
from
time
import
sleep
from
time
import
sleep
import
os
import
os
from
os.path
import
join
import
servicemanager
import
servicemanager
import
socket
import
socket
import
sys
import
sys
import
winerror
import
win32event
import
win32event
import
win32service
import
win32service
import
win32serviceutil
import
win32serviceutil
logger
=
logging
.
getLogger
()
from
utils
import
setup_logging
fh
=
NTEventLogHandler
(
from
windows.winutils
import
getRegistryVal
,
get_windows_version
,
servicePostUpdate
"CIRCLE Watchdog"
,
dllname
=
os
.
path
.
dirname
(
__file__
))
if
getattr
(
sys
,
"frozen"
,
False
):
logger
=
setup_logging
(
logfile
=
r"C:\Circle\watchdog.log"
)
else
:
logger
=
setup_logging
()
fh
=
NTEventLogHandler
(
"CIRCLE Watchdog"
)
formatter
=
logging
.
Formatter
(
formatter
=
logging
.
Formatter
(
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
fh
.
setFormatter
(
formatter
)
fh
.
setFormatter
(
formatter
)
logger
.
addHandler
(
fh
)
logger
.
addHandler
(
fh
)
level
=
os
.
environ
.
get
(
'LOGLEVEL'
,
'INFO'
)
level
=
getRegistryVal
(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-Agent\\Parameters"
,
"LogLevel"
,
"INFO"
)
logger
.
setLevel
(
level
)
logger
.
setLevel
(
level
)
logger
.
info
(
"
%
s loaded"
,
__file__
)
logger
.
info
(
"
%
s loaded"
,
__file__
)
service_name
=
"circle-agent"
stopped
=
False
class
AppServerSvc
(
win32serviceutil
.
ServiceFramework
):
_svc_name_
=
"circle-watchdog"
_svc_display_name_
=
"CIRCLE Watchdog"
_svc_description_
=
"Watchdog for CIRCLE Agent"
def
__init__
(
self
,
args
):
win32serviceutil
.
ServiceFramework
.
__init__
(
self
,
args
)
self
.
hWaitStop
=
win32event
.
CreateEvent
(
None
,
0
,
0
,
None
)
socket
.
setdefaulttimeout
(
60
)
self
.
_stopped
=
False
def
watch
():
def
watch
(
self
,
checked_service
):
def
check_service
(
service_name
):
logger
.
debug
(
"watch..."
)
return
win32serviceutil
.
QueryServiceStatus
(
service_name
)[
1
]
==
4
def
check_service
(
checked_service
):
return
win32serviceutil
.
QueryServiceStatus
(
checked_service
)[
1
]
==
4
def
start_service
():
def
start_service
():
win32serviceutil
.
StartService
(
service_nam
e
)
win32serviceutil
.
StartService
(
checked_servic
e
)
timo_base
=
20
timo_base
=
20
timo
=
timo_base
timo
=
timo_base
sleep
(
6
*
timo
)
# boot process may have triggered the agent, so we are patient
sleep
(
6
*
timo
)
# boot process may have triggered the agent, so we are patient
while
True
:
while
not
self
.
_stopped
:
if
not
check_service
(
service_name
):
logger
.
debug
(
"checking....(timo:
%
d"
,
timo
)
logger
.
info
(
"Service
%
s is not running."
,
service_name
)
if
not
check_service
(
checked_service
):
logger
.
info
(
"Service
%
s is not running."
,
checked_service
)
try
:
try
:
start_service
()
start_service
()
timo
=
timo_base
timo
=
timo_base
logger
.
info
(
"Service
%
s started."
,
service_name
)
logger
.
info
(
"Service
%
s restarted."
,
checked_service
)
except
Exception
as
e
:
except
Exception
:
timo
*=
2
timo
=
min
(
timo
*
2
,
15
*
60
)
# max 15 perc
logger
.
exception
(
"Cant start service
%
s new timo:
%
s"
%
(
service_name
,
timo
))
logger
.
exception
(
"Cant start service
%
s new timo:
%
s"
%
(
checked_service
,
timo
))
if
stopped
:
return
sleep
(
timo
)
sleep
(
timo
)
class
AppServerSvc
(
win32serviceutil
.
ServiceFramework
):
_svc_name_
=
"circle-watchdog"
_svc_display_name_
=
"CIRCLE Watchdog"
def
__init__
(
self
,
args
):
win32serviceutil
.
ServiceFramework
.
__init__
(
self
,
args
)
self
.
hWaitStop
=
win32event
.
CreateEvent
(
None
,
0
,
0
,
None
)
socket
.
setdefaulttimeout
(
60
)
def
SvcStop
(
self
):
def
SvcStop
(
self
):
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOP_PENDING
)
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOP_PENDING
)
win32event
.
SetEvent
(
self
.
hWaitStop
)
win32event
.
SetEvent
(
self
.
hWaitStop
)
global
stopped
self
.
_stopped
=
True
stopped
=
True
logger
.
info
(
"
%
s stopped"
,
__file__
)
logger
.
info
(
"
%
s stopped"
,
__file__
)
def
SvcDoRun
(
self
):
def
SvcDoRun
(
self
):
...
@@ -70,11 +78,27 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
...
@@ -70,11 +78,27 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager
.
PYS_SERVICE_STARTED
,
servicemanager
.
PYS_SERVICE_STARTED
,
(
self
.
_svc_name_
,
''
))
(
self
.
_svc_name_
,
''
))
logger
.
info
(
"
%
s starting"
,
__file__
)
logger
.
info
(
"
%
s starting"
,
__file__
)
watch
()
working_dir
=
r"C:\circle"
exe
=
"circle-watchdog.exe"
exe_path
=
join
(
working_dir
,
exe
)
logger
.
debug
(
"hahooo
%
s
%
s"
,
self
.
_svc_name_
,
exe_path
)
if
servicePostUpdate
(
self
.
_svc_name_
,
exe_path
):
# Service updated, Restart needed
logger
.
debug
(
"update...."
)
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOPPED
,
win32ExitCode
=
winerror
.
ERROR_SERVICE_SPECIFIC_ERROR
,
# 1066
svcExitCode
=
int
(
1
)
)
return
self
.
watch
(
"circle-agent"
)
# normal stop
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOPPED
)
def
main
():
def
main
():
logger
.
info
(
"Started:
%
s"
,
sys
.
argv
)
logger
.
info
(
"Started:
%
s"
,
sys
.
argv
)
if
len
(
sys
.
argv
)
==
1
:
if
len
(
sys
.
argv
)
==
1
:
# service must be starting...
# service must be starting...
# for the sake of debugging etc, we use win32traceutil to see
# for the sake of debugging etc, we use win32traceutil to see
...
...
agent-winservice.py
View file @
67869460
import
logging
from
logging.handlers
import
NTEventLogHandler
import
os
import
os
import
servicemanager
import
servicemanager
import
socket
import
socket
import
sys
import
sys
import
winerror
import
win32event
import
win32event
import
win32service
import
win32service
import
win32serviceutil
import
win32serviceutil
import
logging
from
logging.handlers
import
NTEventLogHandler
#import agent
#import agent, reactor
from
agent
import
main
as
agent_main
,
reactor
from
agent
import
main
as
agent_main
from
twisted.internet
import
reactor
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
()
fh
=
NTEventLogHandler
(
fh
=
NTEventLogHandler
(
"CIRCLE Agent"
)
"CIRCLE Agent"
,
dllname
=
os
.
path
.
dirname
(
__file__
))
formatter
=
logging
.
Formatter
(
formatter
=
logging
.
Formatter
(
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
fh
.
setFormatter
(
formatter
)
fh
.
setFormatter
(
formatter
)
logger
.
addHandler
(
fh
)
logger
.
addHandler
(
fh
)
#level = os.environ.get('LOGLEVEL', 'DEBUG')
#logger.propagate = False
#logger.setLevel('INFO')
logger
.
setLevel
(
'INFO'
)
logger
.
info
(
"
%
s loaded"
,
__file__
)
logger
.
info
(
"
%
s loaded"
,
__file__
)
class
AppServerSvc
(
win32serviceutil
.
ServiceFramework
):
class
AppServerSvc
(
win32serviceutil
.
ServiceFramework
):
_svc_name_
=
"circle-agent"
_svc_name_
=
"circle-agent"
_svc_display_name_
=
"CIRCLE Agent"
_svc_display_name_
=
"CIRCLE Agent"
_svc_desciption_
=
"CIRCLE cloud contextualization agent"
_svc_desc
r
iption_
=
"CIRCLE cloud contextualization agent"
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
win32serviceutil
.
ServiceFramework
.
__init__
(
self
,
args
)
win32serviceutil
.
ServiceFramework
.
__init__
(
self
,
args
)
...
@@ -37,7 +38,12 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
...
@@ -37,7 +38,12 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def
SvcStop
(
self
):
def
SvcStop
(
self
):
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOP_PENDING
)
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOP_PENDING
)
win32event
.
SetEvent
(
self
.
hWaitStop
)
win32event
.
SetEvent
(
self
.
hWaitStop
)
try
:
reactor
.
stop
()
reactor
.
stop
()
reactor
.
callFromThread
(
reactor
.
stop
)
# reactor.callLater(0, reactor.stop)
except
Exception
:
logger
.
exception
(
"reactor.stop failed"
)
logger
.
info
(
"
%
s stopped"
,
__file__
)
logger
.
info
(
"
%
s stopped"
,
__file__
)
def
SvcDoRun
(
self
):
def
SvcDoRun
(
self
):
...
@@ -49,12 +55,25 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
...
@@ -49,12 +55,25 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
servicemanager
.
PYS_SERVICE_STARTED
,
servicemanager
.
PYS_SERVICE_STARTED
,
(
self
.
_svc_name_
,
''
))
(
self
.
_svc_name_
,
''
))
break
break
except
Exception
as
e
:
except
Exception
:
logger
.
exception
(
"Servicemanager busy?"
,
e
)
logger
.
exception
(
"Servicemanager busy?"
)
cnt
-=
1
cnt
-=
1
if
cnt
:
if
cnt
:
logger
.
info
(
"Starting agent_main"
)
logger
.
info
(
"Starting agent_main"
)
agent_main
()
ret
=
agent_main
()
logger
.
error
(
"agent_main returned ret=
%
r type=
%
s"
,
ret
,
type
(
ret
)
.
__name__
)
logger
.
info
(
"agent_main finished"
)
if
ret
!=
0
:
# “Service-specific error”
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOPPED
,
win32ExitCode
=
winerror
.
ERROR_SERVICE_SPECIFIC_ERROR
,
# 1066
svcExitCode
=
int
(
ret
)
)
return
# normal stop
self
.
ReportServiceStatus
(
win32service
.
SERVICE_STOPPED
)
def
main
():
def
main
():
...
...
agent.py
View file @
67869460
...
@@ -8,16 +8,21 @@ import subprocess
...
@@ -8,16 +8,21 @@ import subprocess
import
sys
import
sys
from
shutil
import
rmtree
from
shutil
import
rmtree
logging
.
basicConfig
(
from
logging.handlers
import
TimedRotatingFileHandler
format
=
"[
%(asctime)
s]
%(levelname)
s [agent
%(process)
d/
%(thread)
d]
%(module)
s.
%(funcName)
s:
%(lineno)
d]
%(message)
s"
,
from
pathlib
import
Path
datefmt
=
"
%
d/
%
b/
%
Y
%
H:
%
M:
%
S"
,
from
utils
import
setup_logging
)
logger
=
logging
.
getLogger
()
if
getattr
(
sys
,
"frozen"
,
False
):
logger
=
setup_logging
(
logfile
=
r"C:\Circle\agent.log"
)
else
:
logger
=
setup_logging
()
level
=
environ
.
get
(
'LOGLEVEL'
,
'INFO'
)
level
=
environ
.
get
(
'LOGLEVEL'
,
'INFO'
)
logger
.
setLevel
(
level
)
logger
.
setLevel
(
level
)
system
=
platform
.
system
()
# noqa
system
=
platform
.
system
()
# noqa
logger
.
debug
(
"system:
%
s"
,
system
)
logger
.
debug
(
"system:
%
s"
,
system
)
win
=
platform
.
system
()
==
"Windows"
if
len
(
sys
.
argv
)
!=
1
and
(
system
==
"Linux"
or
system
==
"FreeBSD"
):
# noqa
if
len
(
sys
.
argv
)
!=
1
and
(
system
==
"Linux"
or
system
==
"FreeBSD"
):
# noqa
logger
.
info
(
"Installing agent on
%
s system"
,
system
)
logger
.
info
(
"Installing agent on
%
s system"
,
system
)
...
@@ -39,9 +44,17 @@ import uptime
...
@@ -39,9 +44,17 @@ import uptime
from
inspect
import
getargs
,
isfunction
from
inspect
import
getargs
,
isfunction
from
utils
import
SerialLineReceiverBase
from
utils
import
SerialLineReceiverBase
# Note: Import everything because later we need to use the BaseContext
if
win
:
# (relative import error.
from
windows.winutils
import
getRegistryVal
,
get_windows_version
from
context
import
BaseContext
,
get_context
,
get_serial
# noqa
level
=
getRegistryVal
(
r"SYSTEM\\CurrentControlSet\\Services\\CIRCLE-agent\\Parameters"
,
"LogLevel"
,
level
)
logger
.
setLevel
(
level
)
# system = get_windows_version()
from
context
import
get_context
,
get_serial
# noqa
try
:
try
:
# Python 2: "unicode" is built-in
# Python 2: "unicode" is built-in
...
@@ -86,14 +99,14 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -86,14 +99,14 @@ class SerialLineReceiver(SerialLineReceiverBase):
self
.
send_command
(
command
=
'agent_stopped'
,
args
=
{})
self
.
send_command
(
command
=
'agent_stopped'
,
args
=
{})
def
mayStartNow
(
self
):
def
mayStartNow
(
self
):
if
Base
Context
.
placed
:
if
Context
.
placed
:
self
.
mayStartNowId
.
stop
()
self
.
mayStartNowId
.
stop
()
logger
.
info
(
"Placed"
)
logger
.
info
(
"Placed"
)
return
return
self
.
send_startMsg
()
self
.
send_startMsg
()
def
tick
(
self
):
def
tick
(
self
):
logger
.
debug
(
"Sending tick"
)
#
logger.debug("Sending tick")
try
:
try
:
self
.
send_status
()
self
.
send_status
()
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -109,7 +122,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -109,7 +122,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
return
d
return
d
def
send_startMsg
(
self
):
def
send_startMsg
(
self
):
logger
.
debug
(
"Sending start message
..."
)
logger
.
debug
(
"Sending start message
:
%
s
%
s"
,
Context
.
get_agent_version
(),
system
)
# Hack for flushing the lower level buffersr
# Hack for flushing the lower level buffersr
self
.
transport
.
dataBuffer
=
b
""
self
.
transport
.
dataBuffer
=
b
""
self
.
transport
.
_tempDataBuffer
=
[]
# will be added to dataBuffer in doWrite
self
.
transport
.
_tempDataBuffer
=
[]
# will be added to dataBuffer in doWrite
...
@@ -132,10 +145,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -132,10 +145,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
"disk"
:
disk_usage
,
"disk"
:
disk_usage
,
"user"
:
{
"count"
:
len
(
psutil
.
users
())}}
"user"
:
{
"count"
:
len
(
psutil
.
users
())}}
self
.
send_response
(
response
=
'status'
,
args
=
args
)
self
.
send_response
(
response
=
'status'
,
args
=
args
)
logger
.
debug
(
"send_status finished"
)
#
logger.debug("send_status finished")
def
_check_args
(
self
,
func
,
args
):
def
_check_args
(
self
,
func
,
args
):
logger
.
debug
(
"_check_args
%
s
%
s"
%
(
func
,
args
))
#
logger.debug("_check_args %s %s" % (func, args))
if
not
isinstance
(
args
,
dict
):
if
not
isinstance
(
args
,
dict
):
raise
TypeError
(
"Arguments should be all keyword-arguments in a "
raise
TypeError
(
"Arguments should be all keyword-arguments in a "
"dict for command
%
s instead of
%
s."
%
"dict for command
%
s instead of
%
s."
%
...
@@ -161,10 +174,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -161,10 +174,10 @@ class SerialLineReceiver(SerialLineReceiverBase):
if
missing_kwargs
:
if
missing_kwargs
:
raise
TypeError
(
"Command
%
s missing arguments:
%
s"
%
(
raise
TypeError
(
"Command
%
s missing arguments:
%
s"
%
(
self
.
_pretty_fun
(
func
),
", "
.
join
(
missing_kwargs
)))
self
.
_pretty_fun
(
func
),
", "
.
join
(
missing_kwargs
)))
logger
.
debug
(
"_check_args finished"
)
#
logger.debug("_check_args finished")
def
_get_command
(
self
,
command
,
args
):
def
_get_command
(
self
,
command
,
args
):
logger
.
debug
(
"_get_command
%
s
%
s"
%
(
command
,
args
)
)
# logger.debug("_get_command %s" % command
)
if
not
isinstance
(
command
,
unicode
)
or
command
.
startswith
(
'_'
):
if
not
isinstance
(
command
,
unicode
)
or
command
.
startswith
(
'_'
):
raise
AttributeError
(
u'Invalid command:
%
s'
%
command
)
raise
AttributeError
(
u'Invalid command:
%
s'
%
command
)
try
:
try
:
...
@@ -177,7 +190,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -177,7 +190,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
self
.
_pretty_fun
(
func
))
self
.
_pretty_fun
(
func
))
self
.
_check_args
(
func
,
args
)
self
.
_check_args
(
func
,
args
)
logger
.
debug
(
"_get_command finished"
)
#
logger.debug("_get_command finished")
return
func
return
func
@staticmethod
@staticmethod
...
@@ -194,9 +207,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -194,9 +207,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
return
"<
%
s>"
%
type
(
fun
)
.
__name__
return
"<
%
s>"
%
type
(
fun
)
.
__name__
def
handle_command
(
self
,
command
,
args
):
def
handle_command
(
self
,
command
,
args
):
logger
.
debug
(
"handle_command
%
s
%
s"
%
(
command
,
args
)
)
logger
.
debug
(
"handle_command
%
s
"
%
command
)
func
=
self
.
_get_command
(
command
,
args
)
func
=
self
.
_get_command
(
command
,
args
)
logger
.
debug
(
"Call cmd:
%
s
%
s"
%
(
func
,
args
)
)
# logger.debug("Call cmd: %s" % func
)
retval
=
func
(
**
args
)
retval
=
func
(
**
args
)
logger
.
debug
(
"Retval:
%
s"
%
retval
)
logger
.
debug
(
"Retval:
%
s"
%
retval
)
self
.
send_response
(
self
.
send_response
(
...
@@ -207,6 +220,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
...
@@ -207,6 +220,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
pass
pass
def
main
():
def
main
():
if
Context
.
postUpdate
():
# Service updated, Restart needed
return
1
# Get proper serial class and port name
# Get proper serial class and port name
(
serial
,
port
)
=
get_serial
()
(
serial
,
port
)
=
get_serial
()
logger
.
info
(
"Opening port
%
s"
,
port
)
logger
.
info
(
"Opening port
%
s"
,
port
)
...
@@ -220,7 +238,7 @@ def main():
...
@@ -220,7 +238,7 @@ def main():
logger
.
debug
(
"Starting reactor."
)
logger
.
debug
(
"Starting reactor."
)
reactor
.
run
()
reactor
.
run
()
logger
.
debug
(
"Reactor finished."
)
logger
.
debug
(
"Reactor finished."
)
return
Context
.
exit_code
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
circle-notify.pyw
View file @
67869460
...
@@ -3,15 +3,20 @@
...
@@ -3,15 +3,20 @@
# Should be in autostart and run by the user logged in
# Should be in autostart and run by the user logged in
import
logging
import
logging
from
os.path
import
join
from
os
import
environ
from
notify
import
run_client
,
get_temp_dir
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
()
fh
=
logging
.
FileHandler
(
"agent-client.log"
)
logfile
=
join
(
get_temp_dir
(),
"agent-client.log"
)
fh
=
logging
.
FileHandler
(
logfile
)
formatter
=
logging
.
Formatter
(
formatter
=
logging
.
Formatter
(
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
"
%(asctime)
s -
%(name)
s [
%(levelname)
s]
%(message)
s"
)
fh
.
setFormatter
(
formatter
)
fh
.
setFormatter
(
formatter
)
logger
.
addHandler
(
fh
)
logger
.
addHandler
(
fh
)
level
=
environ
.
get
(
'LOGLEVEL'
,
'INFO'
)
logger
.
setLevel
(
level
)
from
notify
import
run_client
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
run_client
()
run_client
()
context.py
View file @
67869460
...
@@ -5,7 +5,21 @@ import logging
...
@@ -5,7 +5,21 @@ import logging
import
platform
import
platform
import
re
import
re
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
__name__
)
# --- compatibility patch for old libs expecting inspect.getargspec (Py<3.11) ---
import
inspect
from
collections
import
namedtuple
if
not
hasattr
(
inspect
,
"getargspec"
):
ArgSpec
=
namedtuple
(
"ArgSpec"
,
"args varargs keywords defaults"
)
def
getargspec
(
func
):
fs
=
inspect
.
getfullargspec
(
func
)
return
ArgSpec
(
fs
.
args
,
fs
.
varargs
,
fs
.
varkw
,
fs
.
defaults
)
inspect
.
getargspec
=
getargspec
# ---------------------------------------------------------------------------
def
_get_virtio_device
():
def
_get_virtio_device
():
path
=
None
path
=
None
...
@@ -78,6 +92,11 @@ def get_serial():
...
@@ -78,6 +92,11 @@ def get_serial():
class
BaseContext
(
object
):
class
BaseContext
(
object
):
placed
=
False
# if we reciwed password or net commands
placed
=
False
# if we reciwed password or net commands
exit_code
=
0
@staticmethod
def
postUpdate
():
return
false
@staticmethod
@staticmethod
def
change_password
(
password
):
def
change_password
(
password
):
...
...
notify.py
View file @
67869460
...
@@ -34,7 +34,7 @@ try:
...
@@ -34,7 +34,7 @@ try:
except
NameError
:
except
NameError
:
unicode
=
str
unicode
=
str
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
__name__
)
logger
.
debug
(
"notify imported"
)
logger
.
debug
(
"notify imported"
)
file_name
=
"vm_renewal.json"
file_name
=
"vm_renewal.json"
win
=
platform
.
system
()
==
"Windows"
win
=
platform
.
system
()
==
"Windows"
...
@@ -51,6 +51,8 @@ def parse_arguments():
...
@@ -51,6 +51,8 @@ def parse_arguments():
def
get_temp_dir
():
def
get_temp_dir
():
if
os
.
getenv
(
"TMPDIR"
):
if
os
.
getenv
(
"TMPDIR"
):
temp_dir
=
os
.
getenv
(
"TMPDIR"
)
temp_dir
=
os
.
getenv
(
"TMPDIR"
)
elif
os
.
getenv
(
"TEMP"
):
temp_dir
=
os
.
getenv
(
"TEMP"
)
elif
os
.
getenv
(
"TMP"
):
elif
os
.
getenv
(
"TMP"
):
temp_dir
=
os
.
getenv
(
"TMP"
)
temp_dir
=
os
.
getenv
(
"TMP"
)
elif
os
.
path
.
exists
(
"/tmp"
):
elif
os
.
path
.
exists
(
"/tmp"
):
...
@@ -72,10 +74,11 @@ def wall(text):
...
@@ -72,10 +74,11 @@ def wall(text):
process
.
communicate
(
input
=
text
)[
0
]
process
.
communicate
(
input
=
text
)[
0
]
def
accept
():
def
accept
(
url
=
None
):
import
datetime
import
datetime
from
tzlocal
import
get_localzone
from
tzlocal
import
get_localzone
from
pytz
import
UTC
from
pytz
import
UTC
if
url
==
None
:
file_path
=
os
.
path
.
join
(
get_temp_dir
(),
file_name
)
file_path
=
os
.
path
.
join
(
get_temp_dir
(),
file_name
)
if
not
os
.
path
.
isfile
(
file_path
):
if
not
os
.
path
.
isfile
(
file_path
):
print
(
"There is no recent notification to accept."
)
print
(
"There is no recent notification to accept."
)
...
@@ -83,9 +86,13 @@ def accept():
...
@@ -83,9 +86,13 @@ def accept():
# Load the saved url
# Load the saved url
url
=
json
.
load
(
open
(
file_path
,
"r"
))
url
=
json
.
load
(
open
(
file_path
,
"r"
))
os
.
remove
(
file_path
)
cj
=
cookielib
.
CookieJar
()
cj
=
cookielib
.
CookieJar
()
opener
=
urllib2
.
build_opener
(
urllib2
.
HTTPCookieProcessor
(
cj
))
opener
=
urllib2
.
build_opener
(
urllib2
.
HTTPCookieProcessor
(
cj
))
msh
=
None
ret
=
False
new_local_time
=
None
try
:
try
:
opener
.
open
(
url
)
# GET to collect cookies
opener
.
open
(
url
)
# GET to collect cookies
cookies
=
cj
.
_cookies_for_request
(
urllib2
.
Request
(
url
))
cookies
=
cj
.
_cookies_for_request
(
urllib2
.
Request
(
url
))
...
@@ -95,25 +102,23 @@ def accept():
...
@@ -95,25 +102,23 @@ def accept():
b
"x-csrftoken"
:
token
})
b
"x-csrftoken"
:
token
})
rsp
=
opener
.
open
(
req
)
rsp
=
opener
.
open
(
req
)
data
=
json
.
load
(
rsp
)
data
=
json
.
load
(
rsp
)
logger
.
debug
(
"data
%
r"
,
data
)
newtime
=
data
[
"new_suspend_time"
]
newtime
=
data
[
"new_suspend_time"
]
# Parse time from JSON (Create UTC Localized Datetime objec)
msg
=
data
[
"message"
]
parsed_time
=
datetime
.
datetime
.
strptime
(
# # Parse time from JSON (Create UTC Localized Datetime objec)
newtime
[:
-
6
],
"
%
Y-
%
m-
%
d
%
H:
%
M:
%
S.
%
f"
)
.
replace
(
tzinfo
=
UTC
)
# parsed_time = datetime.datetime.strptime(
# Convert to the machine localization
# newtime[:-6], "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=UTC)
new_local_time
=
parsed_time
.
astimezone
(
# # Convert to the machine localization
get_localzone
())
.
strftime
(
"
%
Y-
%
m-
%
d
%
H:
%
M:
%
S"
)
# new_local_time = parsed_time.astimezone(
# get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except
ValueError
as
e
:
except
ValueError
as
e
:
print
(
"Parsing time failed:
%
s"
%
e
)
msg
=
"Parsing time failed:
%
s"
%
e
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
msg
=
"Renewal failed. Please try it manually at
%
s"
%
url
print
(
"Renewal failed. Please try it manually at
%
s"
%
url
)
logger
.
exception
(
"renew failed"
)
logger
.
exception
(
"renew failed"
)
return
False
else
:
else
:
print
(
"Renew succeeded. The machine will be "
ret
=
True
"suspended at
%
s."
%
new_local_time
)
return
{
'ret'
:
ret
,
'msg'
:
msg
,
'new_local_time'
:
new_local_time
}
os
.
remove
(
file_path
)
return
True
def
notify
(
url
):
def
notify
(
url
):
...
@@ -220,6 +225,8 @@ def search_display():
...
@@ -220,6 +225,8 @@ def search_display():
if
win
:
if
win
:
from
twisted.internet
import
protocol
from
twisted.internet
import
protocol
from
twisted.protocols
import
basic
from
twisted.protocols
import
basic
from
winotify
import
Notification
from
prompt
import
prompt_yes_no
clients
=
set
()
clients
=
set
()
port
=
25683
port
=
25683
...
@@ -260,13 +267,40 @@ if win:
...
@@ -260,13 +267,40 @@ if win:
def
lineReceived
(
self
,
line
):
def
lineReceived
(
self
,
line
):
logger
.
debug
(
"received
%
s
%
s"
%
(
line
,
type
(
line
)))
#
logger.debug("received %s %s" % (line, type(line)))
if
not
isinstance
(
line
,
str
):
if
not
isinstance
(
line
,
str
):
line
=
line
.
decode
()
line
=
line
.
decode
()
if
line
.
startswith
(
'cifs://'
):
if
line
.
startswith
(
'cifs://'
):
mount_smb
(
line
)
mount_smb
(
line
)
else
:
else
:
open_in_browser
(
line
)
file_path
=
os
.
path
.
join
(
get_temp_dir
(),
file_name
)
if
file_already_exists
(
file_path
):
os
.
remove
(
file_path
)
if
file_already_exists
(
file_path
):
raise
Exception
(
"Couldn't create file
%
s as new"
%
file_path
)
with
open
(
file_path
,
"w"
)
as
f
:
json
.
dump
(
line
,
f
)
# open_in_browser(line)
toast
=
Notification
(
app_id
=
"CIRCLE Agent"
,
title
=
"VM expiration"
,
msg
=
"VM expiring soon"
,
duration
=
"long"
)
toast
.
add_actions
(
label
=
"renrew"
,
launch
=
line
)
toast
.
show
()
ans
=
prompt_yes_no
(
title
=
"Warning"
,
message
=
"This VM expiring soon"
,
yes_label
=
"Renew"
,
no_label
=
"Cancel"
,
timeout_seconds
=
60
,
)
if
ans
==
"yes"
:
ret
=
accept
(
line
)
prompt_yes_no
(
title
=
"Info"
,
message
=
ret
[
'msg'
],
yes_label
=
"OK"
,
no_label
=
""
,
timeout_seconds
=
10
if
ret
[
'ret'
]
else
60
)
class
SubFactory
(
protocol
.
ReconnectingClientFactory
):
class
SubFactory
(
protocol
.
ReconnectingClientFactory
):
...
...
prompt.py
0 → 100644
View file @
67869460
# prompt.py
# Tkinter modal prompt: 1 or 2 buttons, optional timeout.
# Thread-free (Tkinter-safe). Good for PyInstaller.
from
__future__
import
annotations
from
typing
import
Optional
,
Literal
Result
=
Literal
[
"yes"
,
"no"
,
"timeout"
]
def
prompt_yes_no
(
title
:
str
,
message
:
str
,
yes_label
:
str
=
"Yes"
,
no_label
:
str
=
"No"
,
# if empty/whitespace => single-button mode
default
:
Literal
[
"yes"
,
"no"
]
=
"yes"
,
timeout_seconds
:
Optional
[
int
]
=
None
,
show_countdown
:
bool
=
True
,
topmost
:
bool
=
True
,
wrap_width
:
int
=
420
,
)
->
Result
:
import
tkinter
as
tk
from
tkinter
import
ttk
single_button
=
(
no_label
.
strip
()
==
""
)
# Default result if user closes window (X) or ESC
result_value
:
Result
=
"no"
root
=
tk
.
Tk
()
root
.
title
(
title
)
root
.
resizable
(
False
,
False
)
if
topmost
:
root
.
attributes
(
"-topmost"
,
True
)
frame
=
ttk
.
Frame
(
root
,
padding
=
16
)
frame
.
grid
(
row
=
0
,
column
=
0
)
lbl
=
ttk
.
Label
(
frame
,
text
=
message
,
justify
=
"left"
,
wraplength
=
wrap_width
)
lbl
.
grid
(
row
=
0
,
column
=
0
,
columnspan
=
2
,
sticky
=
"w"
)
countdown_var
=
tk
.
StringVar
(
value
=
""
)
countdown_lbl
=
ttk
.
Label
(
frame
,
textvariable
=
countdown_var
)
countdown_lbl
.
grid
(
row
=
1
,
column
=
0
,
columnspan
=
2
,
sticky
=
"w"
,
pady
=
(
8
,
0
))
after_id
=
None
cancelled
=
False
def
cancel_timer
():
nonlocal
after_id
if
after_id
is
not
None
:
try
:
root
.
after_cancel
(
after_id
)
except
Exception
:
pass
after_id
=
None
def
finish
(
value
:
Result
):
nonlocal
cancelled
,
result_value
cancelled
=
True
result_value
=
value
cancel_timer
()
try
:
root
.
destroy
()
except
Exception
:
pass
def
on_yes
():
finish
(
"yes"
)
def
on_no
():
finish
(
"no"
)
# Buttons row
btn_row
=
2
btn_yes
=
ttk
.
Button
(
frame
,
text
=
yes_label
,
command
=
on_yes
)
btn_yes
.
grid
(
row
=
btn_row
,
column
=
0
,
padx
=
6
,
pady
=
(
16
,
0
))
btn_no
=
None
if
not
single_button
:
btn_no
=
ttk
.
Button
(
frame
,
text
=
no_label
,
command
=
on_no
)
btn_no
.
grid
(
row
=
btn_row
,
column
=
1
,
padx
=
6
,
pady
=
(
16
,
0
))
else
:
# In single-button mode, span across
btn_yes
.
grid_configure
(
column
=
0
,
columnspan
=
2
)
# Default focus
if
default
==
"yes"
or
single_button
:
btn_yes
.
focus_set
()
else
:
if
btn_no
is
not
None
:
btn_no
.
focus_set
()
# Window close acts like "no" (safe default)
root
.
protocol
(
"WM_DELETE_WINDOW"
,
on_no
)
# Keyboard shortcuts:
# ESC => No
root
.
bind
(
"<Escape>"
,
lambda
_evt
:
on_no
())
# Enter => default (or the only button)
if
single_button
or
default
==
"yes"
:
root
.
bind
(
"<Return>"
,
lambda
_evt
:
on_yes
())
else
:
root
.
bind
(
"<Return>"
,
lambda
_evt
:
on_no
())
# Center window
root
.
update_idletasks
()
w
,
h
=
root
.
winfo_width
(),
root
.
winfo_height
()
sw
,
sh
=
root
.
winfo_screenwidth
(),
root
.
winfo_screenheight
()
root
.
geometry
(
f
"+{(sw - w)//2}+{(sh - h)//3}"
)
# Timeout (thread-free)
if
timeout_seconds
and
timeout_seconds
>
0
:
remaining
=
int
(
timeout_seconds
)
def
tick
():
nonlocal
remaining
,
after_id
if
cancelled
or
not
root
.
winfo_exists
():
return
if
remaining
<=
0
:
finish
(
"timeout"
)
return
if
show_countdown
:
countdown_var
.
set
(
f
"Auto close in {remaining} seconds…"
)
remaining
-=
1
after_id
=
root
.
after
(
1000
,
tick
)
after_id
=
root
.
after
(
1000
,
tick
)
else
:
# If no timeout, hide countdown line (optional)
if
not
show_countdown
:
countdown_var
.
set
(
""
)
root
.
mainloop
()
return
result_value
if
__name__
==
"__main__"
:
# Demo
r
=
prompt_yes_no
(
title
=
"VM expiring"
,
message
=
"A VM 5 percen belul suspend lesz.
\n\n
Szeretned meghosszabbitani?"
,
yes_label
=
"Renew"
,
no_label
=
"Cancel"
,
timeout_seconds
=
20
,
)
print
(
"Result:"
,
r
)
r2
=
prompt_yes_no
(
title
=
"Info"
,
message
=
"This is a single-button message."
,
yes_label
=
"OK"
,
no_label
=
""
,
timeout_seconds
=
10
,
)
print
(
"Result2:"
,
r2
)
requirements/windows.txt
View file @
67869460
...
@@ -10,4 +10,5 @@ tzlocal
...
@@ -10,4 +10,5 @@ tzlocal
pytz
pytz
pywin32
pywin32
wmi
wmi
winotify
utils.py
View file @
67869460
from
twisted.protocols.basic
import
LineReceiver
from
twisted.protocols.basic
import
LineReceiver
import
sys
import
json
import
json
import
logging
import
logging
from
logging.handlers
import
TimedRotatingFileHandler
import
platform
import
platform
from
os
import
chmod
from
os
import
chmod
from
shutil
import
copyfile
from
shutil
import
copyfile
...
@@ -11,9 +13,60 @@ try:
...
@@ -11,9 +13,60 @@ try:
except
NameError
:
except
NameError
:
unicode
=
str
unicode
=
str
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
__name__
)
system
=
platform
.
system
()
system
=
platform
.
system
()
def
setup_logging
(
logfile
=
None
,
backup_count
=
3
):
logger
=
logging
.
getLogger
()
logger
.
handlers
.
clear
()
formatter
=
logging
.
Formatter
(
"[
%(asctime)
s]
%(levelname)
s [agent
%(process)
d/
%(thread)
d] "
"
%(module)
s.
%(funcName)
s:
%(lineno)
d]
%(message)
s"
,
"
%
d/
%
b/
%
Y
%
H:
%
M:
%
S"
,
)
if
logfile
!=
None
:
handler
=
TimedRotatingFileHandler
(
filename
=
logfile
,
when
=
"midnight"
,
backupCount
=
backup_count
,
encoding
=
"utf-8"
,
delay
=
True
,
)
else
:
handler
=
logging
.
StreamHandler
(
sys
.
stderr
)
handler
.
setFormatter
(
formatter
)
logger
.
addHandler
(
handler
)
return
logger
def
setup_logging_timed
(
logfile
:
str
,
level
:
str
,
backup_count
:
int
):
handler
=
TimedRotatingFileHandler
(
filename
=
str
(
log_path
),
when
=
"midnight"
,
interval
=
1
,
backupCount
=
backup_count
,
encoding
=
"utf-8"
,
utc
=
False
,
# use local time (Budapest)
delay
=
True
,
# create file only when first log is emitted
)
formatter
=
logging
.
Formatter
(
format
=
"[
%(asctime)
s]
%(levelname)
s [agent
%(process)
d/
%(thread)
d] "
"
%(module)
s.
%(funcName)
s:
%(lineno)
d]
%(message)
s"
,
datefmt
=
"
%
d/
%
b/
%
Y
%
H:
%
M:
%
S"
,
)
handler
.
setFormatter
(
formatter
)
root
=
logging
.
getLogger
()
root
.
setLevel
(
level
)
# Avoid duplicate handlers if setup called multiple times
root
.
handlers
.
clear
()
root
.
addHandler
(
handler
)
return
root
class
SerialLineReceiverBase
(
LineReceiver
,
object
):
class
SerialLineReceiverBase
(
LineReceiver
,
object
):
MAX_LENGTH
=
1024
*
1024
*
128
MAX_LENGTH
=
1024
*
1024
*
128
...
@@ -26,7 +79,7 @@ class SerialLineReceiverBase(LineReceiver, object):
...
@@ -26,7 +79,7 @@ class SerialLineReceiverBase(LineReceiver, object):
super
(
SerialLineReceiverBase
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
super
(
SerialLineReceiverBase
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
def
send_response
(
self
,
response
,
args
):
def
send_response
(
self
,
response
,
args
):
logger
.
debug
(
"send_response
%
s
%
s"
%
(
response
,
args
))
#
logger.debug("send_response %s %s" % (response, args))
self
.
transport
.
write
(
json
.
dumps
({
'response'
:
response
,
self
.
transport
.
write
(
json
.
dumps
({
'response'
:
response
,
'args'
:
args
})
+
'
\r\n
'
)
'args'
:
args
})
+
'
\r\n
'
)
...
@@ -41,7 +94,7 @@ class SerialLineReceiverBase(LineReceiver, object):
...
@@ -41,7 +94,7 @@ class SerialLineReceiverBase(LineReceiver, object):
raise
NotImplementedError
(
"Subclass must implement abstract method"
)
raise
NotImplementedError
(
"Subclass must implement abstract method"
)
def
lineReceived
(
self
,
data
):
def
lineReceived
(
self
,
data
):
logger
.
debug
(
"lineReceived:
%
s"
,
data
)
#
logger.debug("lineReceived: %s", data)
if
(
isinstance
(
data
,
unicode
)):
if
(
isinstance
(
data
,
unicode
)):
data
=
data
.
strip
(
'
\0
'
)
data
=
data
.
strip
(
'
\0
'
)
else
:
else
:
...
@@ -53,14 +106,14 @@ class SerialLineReceiverBase(LineReceiver, object):
...
@@ -53,14 +106,14 @@ class SerialLineReceiverBase(LineReceiver, object):
args
=
{}
args
=
{}
command
=
data
.
get
(
'command'
,
None
)
command
=
data
.
get
(
'command'
,
None
)
response
=
data
.
get
(
'response'
,
None
)
response
=
data
.
get
(
'response'
,
None
)
logger
.
debug
(
'[serial] valid json:
%
s'
%
(
data
,
))
#
logger.debug('[serial] valid json: %s' % (data, ))
except
(
ValueError
,
KeyError
)
as
e
:
except
(
ValueError
,
KeyError
)
as
e
:
logger
.
error
(
'[serial] invalid json:
%
s (
%
s)'
%
(
data
,
e
))
#
logger.error('[serial] invalid json: %s (%s)' % (data, e))
self
.
clearLineBuffer
()
self
.
clearLineBuffer
()
return
return
if
command
is
not
None
and
isinstance
(
command
,
unicode
):
if
command
is
not
None
and
isinstance
(
command
,
unicode
):
logger
.
debug
(
'received command:
%
s (
%
s)'
%
(
command
,
args
))
# logger.debug('received command: %s (%s)' % (command, args[:10]
))
try
:
try
:
self
.
handle_command
(
command
,
args
)
self
.
handle_command
(
command
,
args
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -93,3 +146,4 @@ def copy_file(src, dst, overw=False, mode=None):
...
@@ -93,3 +146,4 @@ def copy_file(src, dst, overw=False, mode=None):
return
copyed
return
copyed
win_build.bat
View file @
67869460
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F watchdog-winservice.py
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-wdog-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F agent-winservice.py
pyinstaller --clean --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw
pyinstaller --clean -F --path . --hidden-import pkg_resources --hidden-import infi --hidden-import win32timezone --hidden-import win32traceutil -F circle-notify.pyw
\ No newline at end of file
\ No newline at end of file
windows/_win32context.py
View file @
67869460
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
working_directory
=
r"C:\circle"
# noqa
from
os.path
import
join
from
os.path
import
join
import
sys
import
logging
import
logging
import
tarfile
import
tarfile
from
io
import
BytesIO
from
io
import
BytesIO
...
@@ -19,6 +18,10 @@ from twisted.internet import reactor
...
@@ -19,6 +18,10 @@ from twisted.internet import reactor
from
.network
import
change_ip_windows
from
.network
import
change_ip_windows
from
context
import
BaseContext
from
context
import
BaseContext
from
windows.winutils
import
(
is_frozen_exe
,
copy_running_exe
,
update_service_binpath
,
servicePostUpdate
)
try
:
try
:
# Python 2: "unicode" is built-in
# Python 2: "unicode" is built-in
...
@@ -26,14 +29,22 @@ try:
...
@@ -26,14 +29,22 @@ try:
except
NameError
:
except
NameError
:
unicode
=
str
unicode
=
str
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
__name__
)
class
Context
(
BaseContext
):
class
Context
(
BaseContext
):
service_name
=
"CIRCLE-agent"
working_dir
=
r"C:\circle"
exe
=
"circle-agent.exe"
@staticmethod
def
postUpdate
():
exe_path
=
join
(
Context
.
working_dir
,
Context
.
exe
)
return
servicePostUpdate
(
Context
.
service_name
,
exe_path
)
@staticmethod
@staticmethod
def
change_password
(
password
):
def
change_password
(
password
):
Base
Context
.
placed
=
True
Context
.
placed
=
True
from
win32com
import
adsi
from
win32com
import
adsi
ads_obj
=
adsi
.
ADsGetObject
(
'WinNT://localhost/
%
s,user'
%
'cloud'
)
ads_obj
=
adsi
.
ADsGetObject
(
'WinNT://localhost/
%
s,user'
%
'cloud'
)
ads_obj
.
Getinfo
()
ads_obj
.
Getinfo
()
...
@@ -45,7 +56,7 @@ class Context(BaseContext):
...
@@ -45,7 +56,7 @@ class Context(BaseContext):
@staticmethod
@staticmethod
def
change_ip
(
interfaces
,
dns
):
def
change_ip
(
interfaces
,
dns
):
Base
Context
.
placed
=
True
Context
.
placed
=
True
nameservers
=
dns
.
replace
(
' '
,
''
)
.
split
(
','
)
nameservers
=
dns
.
replace
(
' '
,
''
)
.
split
(
','
)
change_ip_windows
(
interfaces
,
nameservers
)
change_ip_windows
(
interfaces
,
nameservers
)
...
@@ -99,19 +110,6 @@ class Context(BaseContext):
...
@@ -99,19 +110,6 @@ class Context(BaseContext):
myfile
.
write
(
data
)
myfile
.
write
(
data
)
@staticmethod
@staticmethod
def
_update_registry
(
dir
,
executable
):
# HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\circle-agent
from
winreg
import
(
OpenKeyEx
,
SetValueEx
,
QueryValueEx
,
HKEY_LOCAL_MACHINE
,
KEY_ALL_ACCESS
)
with
OpenKeyEx
(
HKEY_LOCAL_MACHINE
,
r'SYSTEM\CurrentControlSet\services\circle-agent'
,
0
,
KEY_ALL_ACCESS
)
as
key
:
(
old_executable
,
reg_type
)
=
QueryValueEx
(
key
,
"ImagePath"
)
SetValueEx
(
key
,
"ImagePath"
,
None
,
2
,
join
(
dir
,
executable
))
return
old_executable
@staticmethod
def
update
(
filename
,
executable
,
checksum
,
uuid
):
def
update
(
filename
,
executable
,
checksum
,
uuid
):
with
open
(
filename
,
"r"
)
as
f
:
with
open
(
filename
,
"r"
)
as
f
:
data
=
f
.
read
()
data
=
f
.
read
()
...
@@ -122,13 +120,14 @@ class Context(BaseContext):
...
@@ -122,13 +120,14 @@ class Context(BaseContext):
decoded
=
BytesIO
(
b64decode
(
data
))
decoded
=
BytesIO
(
b64decode
(
data
))
try
:
try
:
tar
=
tarfile
.
TarFile
.
open
(
"dummy"
,
fileobj
=
decoded
,
mode
=
'r|gz'
)
tar
=
tarfile
.
TarFile
.
open
(
"dummy"
,
fileobj
=
decoded
,
mode
=
'r|gz'
)
tar
.
extractall
(
working_directory
)
tar
.
extractall
(
Context
.
working_dir
)
except
tarfile
.
ReadError
as
e
:
except
tarfile
.
ReadError
as
e
:
logger
.
error
(
e
)
logger
.
error
(
e
)
logger
.
info
(
"Transfer completed!"
)
logger
.
info
(
"Transfer completed!"
)
old_exe
=
Context
.
_update_registry
(
working_directory
,
executable
)
old_exe
=
update_service_binpath
(
"CIRCLE-agent"
,
join
(
Context
.
working_dir
,
executable
)
)
logger
.
info
(
'
%
s Updated'
,
old_exe
)
logger
.
info
(
'
%
s Updated'
,
old_exe
)
reactor
.
stop
()
Context
.
exit_code
=
1
reactor
.
callLater
(
0
,
reactor
.
stop
)
@staticmethod
@staticmethod
def
ipaddresses
():
def
ipaddresses
():
...
@@ -149,7 +148,7 @@ class Context(BaseContext):
...
@@ -149,7 +148,7 @@ class Context(BaseContext):
@staticmethod
@staticmethod
def
get_agent_version
():
def
get_agent_version
():
try
:
try
:
with
open
(
join
(
working_directory
,
'version.txt'
))
as
f
:
with
open
(
join
(
Context
.
working_dir
,
'version.txt'
))
as
f
:
return
f
.
readline
()
return
f
.
readline
()
except
IOError
:
except
IOError
:
return
None
return
None
windows/win32virtio.py
View file @
67869460
...
@@ -20,7 +20,7 @@ from twisted.internet import abstract
...
@@ -20,7 +20,7 @@ from twisted.internet import abstract
# sibling imports
# sibling imports
import
logging
import
logging
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
__name__
)
class
SerialPort
(
abstract
.
FileDescriptor
):
class
SerialPort
(
abstract
.
FileDescriptor
):
...
@@ -68,7 +68,7 @@ class SerialPort(abstract.FileDescriptor):
...
@@ -68,7 +68,7 @@ class SerialPort(abstract.FileDescriptor):
self
.
_overlappedRead
)
self
.
_overlappedRead
)
def
serialReadEvent
(
self
):
def
serialReadEvent
(
self
):
logger
.
debug
(
"serialReadEvent
%
s
%
s"
%
(
self
.
_overlappedRead
.
Internal
,
self
.
_overlappedRead
.
InternalHigh
))
#
logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try
:
try
:
n
=
win32file
.
GetOverlappedResult
(
self
.
hComPort
,
self
.
_overlappedRead
,
1
)
n
=
win32file
.
GetOverlappedResult
(
self
.
hComPort
,
self
.
_overlappedRead
,
1
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -89,14 +89,14 @@ class SerialPort(abstract.FileDescriptor):
...
@@ -89,14 +89,14 @@ class SerialPort(abstract.FileDescriptor):
data
=
str
.
encode
(
data
)
data
=
str
.
encode
(
data
)
if
self
.
writeInProgress
:
if
self
.
writeInProgress
:
self
.
outQueue
.
append
(
data
)
self
.
outQueue
.
append
(
data
)
logger
.
debug
(
"added to queue"
)
#
logger.debug("added to queue")
else
:
else
:
self
.
writeInProgress
=
1
self
.
writeInProgress
=
1
ret
,
n
=
win32file
.
WriteFile
(
self
.
hComPort
,
data
,
self
.
_overlappedWrite
)
ret
,
n
=
win32file
.
WriteFile
(
self
.
hComPort
,
data
,
self
.
_overlappedWrite
)
logger
.
debug
(
"Writed to file
%
s"
,
ret
)
#
logger.debug("Writed to file %s", ret)
def
serialWriteEvent
(
self
):
def
serialWriteEvent
(
self
):
logger
.
debug
(
"serialWriteEvent
%
s
%
s"
%
(
self
.
_overlappedWrite
.
Internal
,
self
.
_overlappedWrite
.
InternalHigh
))
#
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if
self
.
_overlappedWrite
.
Internal
<
0
and
self
.
_overlappedWrite
.
InternalHigh
==
0
:
# DANGER: Not documented variables
if
self
.
_overlappedWrite
.
Internal
<
0
and
self
.
_overlappedWrite
.
InternalHigh
==
0
:
# DANGER: Not documented variables
logger
.
debug
(
self
.
connLost
())
logger
.
debug
(
self
.
connLost
())
self
.
writeInProgress
=
0
self
.
writeInProgress
=
0
...
...
windows/winutils.py
0 → 100644
View file @
67869460
import
os
import
sys
import
logging
from
shutil
import
copy
from
os.path
import
join
,
normcase
,
normpath
from
winreg
import
(
OpenKeyEx
,
QueryValueEx
,
SetValueEx
,
HKEY_LOCAL_MACHINE
,
KEY_ALL_ACCESS
,
KEY_READ
)
logger
=
logging
.
getLogger
()
def
is_frozen_exe
()
->
bool
:
return
bool
(
getattr
(
sys
,
"frozen"
,
False
))
def
update_service_binpath
(
service_name
:
str
,
exe_path
:
str
)
->
str
:
"""
Update service ImagePath in registry to point to exe_path.
Returns the previous ImagePath string.
"""
with
OpenKeyEx
(
HKEY_LOCAL_MACHINE
,
fr
"SYSTEM
\
CurrentControlSet
\
services
\
{service_name}"
,
0
,
KEY_ALL_ACCESS
)
as
key
:
(
old_executable
,
reg_type
)
=
QueryValueEx
(
key
,
"ImagePath"
)
SetValueEx
(
key
,
"ImagePath"
,
None
,
2
,
f
'"{exe_path}"'
)
return
old_executable
def
copy_running_exe
(
dest
:
str
)
->
bool
:
"""
Startup helper:
- If the runnin executable is not
then copy it to dest (overwriting old dest if present),
- Otherwise do nothing.
Returns True if it performed changes, otherwise False.
"""
# Where are we actually running from?
current_exe
=
sys
.
executable
# Windows paths are case-insensitive -> compare with normcase
if
normcase
(
current_exe
)
==
normcase
(
dest
):
return
False
copy
(
current_exe
,
dest
)
return
True
def
servicePostUpdate
(
service_name
,
exe_path
):
logger
.
debug
(
"Running exe
%
s"
,
sys
.
executable
)
if
is_frozen_exe
()
and
copy_running_exe
(
exe_path
):
logger
.
debug
(
"The running agent copyed to
%
s"
,
exe_path
)
old_exe
=
update_service_binpath
(
service_name
,
exe_path
)
logger
.
debug
(
"
%
s service binpath updated
%
s ->
%
s"
,
service_name
,
old_exe
,
exe_path
)
return
True
return
False
def
getRegistryVal
(
reg_path
:
str
,
name
:
str
,
default
=
None
):
"""
Read HKLM
\\
<reg_path>
\\
<name> and return its value.
If key or value does not exist, return default.
Example:
getRegistryVal(
r"SYSTEM
\\
CurrentControlSet
\\
Services
\\
circle-agent",
"LogLevel",
"INFO"
)
"""
value
=
default
try
:
with
OpenKeyEx
(
HKEY_LOCAL_MACHINE
,
reg_path
,
0
,
KEY_READ
)
as
key
:
value
,
_
=
QueryValueEx
(
key
,
name
)
except
Exception
as
e
:
logging
.
debug
(
"Registry read failed
%
s
\\
%
s:
%
s"
,
reg_path
,
name
,
e
)
return
value
def
get_windows_version
():
if
sys
.
platform
!=
"win32"
:
return
None
ver
=
sys
.
getwindowsversion
()
major
=
ver
.
major
minor
=
ver
.
minor
build
=
ver
.
build
# Windows 7
if
major
==
6
and
minor
==
1
:
return
"Windows_7"
# Windows 8 / 8.1
if
major
==
6
and
minor
in
(
2
,
3
):
return
"Windows_8"
# Windows 10 / 11
if
major
==
10
:
# Windows 11 starts at build 22000
if
build
>=
22000
:
return
"Windows_11"
else
:
return
"Windows_10"
return
f
"Windows_{major}_{minor}_{build})"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment