Commit d8264cd8 by Karsa Zoltán István

linter

parent e2461572
# Python bytecode: # Python bytecode:
*.py[co] *.py[co]
.tokens .tokens
.env
.ruff_cache/
# Packaging files: # Packaging files:
*.egg* *.egg*
......
...@@ -4,22 +4,23 @@ from fastapi import HTTPException ...@@ -4,22 +4,23 @@ from fastapi import HTTPException
import requests import requests
import json import json
def proxy_datacenters(serverpath: str, username, method="GET", balancer_fun = rr_get):
def proxy_datacenters(serverpath: str, username, method="GET", balancer_fun=rr_get):
server = balancer_fun() server = balancer_fun()
token = get_datacenter_token(username, server) token = get_datacenter_token(username, server)
url=f"{server}/{serverpath}" url = f"{server}/{serverpath}"
t_resp = requests.request( t_resp = requests.request(
method=method, method=method,
url=url, url=url,
allow_redirects=False, verify=False, allow_redirects=False,
headers={ verify=False,
'Authorization': token headers={"Authorization": token},
}
) )
if t_resp.status_code / 100 != 2: if t_resp.status_code / 100 != 2:
raise HTTPException(status_code=t_resp.status_code, detail="Remote server error") raise HTTPException(
status_code=t_resp.status_code, detail="Remote server error"
)
response = ORJSONResponse( response = ORJSONResponse(
json.loads(t_resp.content), json.loads(t_resp.content), status_code=t_resp.status_code
status_code=t_resp.status_code
) )
return response return response
\ No newline at end of file
...@@ -3,6 +3,7 @@ from typing import Dict ...@@ -3,6 +3,7 @@ from typing import Dict
import jwt import jwt
from passlib.hash import pbkdf2_sha256 from passlib.hash import pbkdf2_sha256
from decouple import config from decouple import config
from fastapi import HTTPException
JWT_SECRET = config("secret") JWT_SECRET = config("secret")
...@@ -10,25 +11,23 @@ JWT_ALGORITHM = config("algorithm") ...@@ -10,25 +11,23 @@ JWT_ALGORITHM = config("algorithm")
def token_response(token: str): def token_response(token: str):
return { return {"access_token": token}
"access_token": token
}
def signJWT(user_id: str) -> Dict[str, str]: def signJWT(user_id: str) -> Dict[str, str]:
payload = { payload = {"user_id": user_id, "expires": time.time() + 60000}
"user_id": user_id,
"expires": time.time() + 60000
}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token_response(token) return token_response(token)
def decodeJWT(token: str) -> dict: def decodeJWT(token: str) -> dict:
try: try:
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return decoded_token if decoded_token["expires"] >= time.time() else None return decoded_token if decoded_token["expires"] >= time.time() else None
except: except jwt.ExpiredSignatureError:
return {} raise HTTPException(status_code=501, detail="JWT token decode error")
def hash_pass(password: str) -> str: def hash_pass(password: str) -> str:
return pbkdf2_sha256.hash(password) return pbkdf2_sha256.hash(password)
\ No newline at end of file
...@@ -6,18 +6,23 @@ from sredis.models import PUser ...@@ -6,18 +6,23 @@ from sredis.models import PUser
class JWTBearer(HTTPBearer): class JWTBearer(HTTPBearer):
def __init__(self, auto_error: bool = True): def __init__(self, auto_error: bool = True):
super(JWTBearer, self).__init__(auto_error=auto_error) super(JWTBearer, self).__init__(auto_error=auto_error)
async def __call__(self, request: Request): async def __call__(self, request: Request):
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request) credentials: HTTPAuthorizationCredentials = await super(
JWTBearer, self
).__call__(request)
if credentials: if credentials:
if not credentials.scheme == "Bearer": if not credentials.scheme == "Bearer":
raise HTTPException(status_code=403, detail="Invalid authentication scheme.") raise HTTPException(
status_code=403, detail="Invalid authentication scheme."
)
cred = self.verify_jwt(credentials.credentials) cred = self.verify_jwt(credentials.credentials)
if not cred: if not cred:
raise HTTPException(status_code=403, detail="Invalid token or expired token.") raise HTTPException(
status_code=403, detail="Invalid token or expired token."
)
return credentials.credentials return credentials.credentials
else: else:
raise HTTPException(status_code=403, detail="Invalid authorization code.") raise HTTPException(status_code=403, detail="Invalid authorization code.")
...@@ -25,21 +30,20 @@ class JWTBearer(HTTPBearer): ...@@ -25,21 +30,20 @@ class JWTBearer(HTTPBearer):
def verify_jwt(self, jwtoken: str) -> bool: def verify_jwt(self, jwtoken: str) -> bool:
isTokenValid: bool = False isTokenValid: bool = False
try: payload = decodeJWT(jwtoken)
payload = decodeJWT(jwtoken)
except:
payload = None
if payload: if payload:
return payload return payload
return isTokenValid return isTokenValid
async def get_current_user(token: str = Depends(JWTBearer())) -> str: async def get_current_user(token: str = Depends(JWTBearer())) -> str:
payload = decodeJWT(token) payload = decodeJWT(token)
return payload['user_id'] return payload["user_id"]
async def admin_user(token: str = Depends(JWTBearer())) -> str: async def admin_user(token: str = Depends(JWTBearer())) -> str:
payload = decodeJWT(token) payload = decodeJWT(token)
user = PUser.find(PUser.username == payload['user_id']).all()[0] user = PUser.find(PUser.username == payload["user_id"]).all()[0]
if not user.admin: if not user.admin:
raise HTTPException(status_code=401, detail="you can not have access") raise HTTPException(status_code=401, detail="you can not have access")
return payload['user_id'] return payload["user_id"]
\ No newline at end of file
from typing import Union
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
class User(BaseModel): class User(BaseModel):
username: str username: str
email: EmailStr email: EmailStr
password: str password: str
class DataCenter(BaseModel): class DataCenter(BaseModel):
name: str name: str
class Token(BaseModel): class Token(BaseModel):
datacenter: str datacenter: str
token: str token: str
class UserLoginSchema(BaseModel): class UserLoginSchema(BaseModel):
username: str username: str
password: str password: str
class Config: class Config:
schema_extra = { schema_extra = {"example": {"username": "user", "password": "weakpassword"}}
"example": {
"username": "user",
"password": "weakpassword"
}
}
\ No newline at end of file
from fastapi import FastAPI, Response, Body, Depends from fastapi import FastAPI, Response, Body, Depends
from typing import List from typing import List
from balancer.util import proxy_datacenters from balancer.util import proxy_datacenters
from sredis.models import *
from sredis.sredis import check_user, create_puser, add_datacenter, set_token from sredis.sredis import check_user, create_puser, add_datacenter, set_token
import logging import logging
import requests import requests
...@@ -10,11 +9,11 @@ from core.auth import signJWT ...@@ -10,11 +9,11 @@ from core.auth import signJWT
from core.bearer import get_current_user, admin_user from core.bearer import get_current_user, admin_user
from redis_om import Migrator from redis_om import Migrator
logging.config.fileConfig('logging.conf', disable_existing_loggers=False) logging.config.fileConfig("logging.conf", disable_existing_loggers=False)
# get root logger # get root logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
Migrator().run() Migrator().run()
app = FastAPI() app = FastAPI()
...@@ -28,41 +27,32 @@ async def create_user(user: User = Body(...)): ...@@ -28,41 +27,32 @@ async def create_user(user: User = Body(...)):
create_puser(user) create_puser(user)
return signJWT(user.username) return signJWT(user.username)
@app.post("/user/login", tags=["user"]) @app.post("/user/login", tags=["user"])
async def user_login(user: UserLoginSchema = Body(...)): async def user_login(user: UserLoginSchema = Body(...)):
if check_user(user): if check_user(user):
return signJWT(user.username) return signJWT(user.username)
return { return {"error": "Wrong login details!"}
"error": "Wrong login details!"
}
@app.get("/lb/{server_path:path}") @app.get("/lb/{server_path:path}")
def proxy( def proxy_get(server_path: str = "/", username=Depends(get_current_user)):
server_path: str = "/",
username = Depends(get_current_user)
):
return proxy_datacenters(server_path, username) return proxy_datacenters(server_path, username)
@app.post("/lb/{server_path:path}") @app.post("/lb/{server_path:path}")
def proxy( def proxy_post(server_path: str = "/", username=Depends(get_current_user)):
server_path: str = "/",
username = Depends(get_current_user)
):
return proxy_datacenters(server_path, username, method="POST") return proxy_datacenters(server_path, username, method="POST")
@app.post("/add_datacenter/") @app.post("/add_datacenter/")
def create_datacenter( def create_datacenter(dc: DataCenter = None, username=Depends(admin_user)):
dc: DataCenter = None,
username = Depends(admin_user)
):
add_datacenter(dc.name) add_datacenter(dc.name)
return Response(status_code=201) return Response(status_code=201)
@app.post("/set_tokens/") @app.post("/set_tokens/")
def set_tokens( def set_tokens(tokens: List[Token] = None, username=Depends(get_current_user)):
tokens: List[Token] = None,
username = Depends(get_current_user)
):
for token in tokens: for token in tokens:
set_token(username, str(token.datacenter), str(token.token)) set_token(username, str(token.datacenter), str(token.token))
return tokens return tokens
...@@ -18,13 +18,15 @@ python-decouple = "^3.8" ...@@ -18,13 +18,15 @@ python-decouple = "^3.8"
redis-om = "^0.1.2" redis-om = "^0.1.2"
pydantic = {extras = ["email"], version = "^1.10.6"} pydantic = {extras = ["email"], version = "^1.10.6"}
passlib = "^1.7.4" passlib = "^1.7.4"
black = "^23.1.0"
ruff = "^0.0.254"
[tool.poe.tasks.start] [tool.poe.tasks.start]
shell = "poetry run uvicorn main:app --reload --port 6973 --host 0.0.0.0" shell = "poetry run uvicorn main:app --reload --port 6973 --host 0.0.0.0"
help = "Start the microservice on port 6973" help = "Start the microservice on port 6973"
[tool.poe.tasks.lint] [tool.poe.tasks.lint]
shell = "poetry run black . && poetry run ruff --fix . && poetry run mypy backend" shell = "poetry run black . && poetry run ruff --fix . "
help = "Lint the most important parts of the microservice with black" help = "Lint the most important parts of the microservice with black"
[build-system] [build-system]
......
from redis_om import HashModel, Field from redis_om import HashModel, Field
from pydantic import EmailStr from pydantic import EmailStr
class PUser(HashModel): class PUser(HashModel):
username: str = Field(index=True) username: str = Field(index=True)
email: EmailStr email: EmailStr
password: str password: str
admin: bool = False admin: bool = False
...@@ -5,13 +5,13 @@ from passlib.hash import pbkdf2_sha256 ...@@ -5,13 +5,13 @@ from passlib.hash import pbkdf2_sha256
from core.models import User from core.models import User
from core.auth import hash_pass from core.auth import hash_pass
from fastapi import HTTPException from fastapi import HTTPException
from redis_om import Migrator
r = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True) r = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
r.set("datacenters_cnt", 1) r.set("datacenters_cnt", 1)
all_keys = list(r.hgetall('datacenters_hash').keys()) all_keys = list(r.hgetall("datacenters_hash").keys())
if all_keys: if all_keys:
r.hdel('datacenters_hash', *all_keys) r.hdel("datacenters_hash", *all_keys)
def add_datacenter(datacenter: str): def add_datacenter(datacenter: str):
cnt = int(r.get("datacenters_cnt")) cnt = int(r.get("datacenters_cnt"))
...@@ -19,35 +19,38 @@ def add_datacenter(datacenter: str): ...@@ -19,35 +19,38 @@ def add_datacenter(datacenter: str):
r.incr("datacenters_cnt") r.incr("datacenters_cnt")
r.set("roundrobin_cnt", 1) r.set("roundrobin_cnt", 1)
def rr_get(): def rr_get():
cnt = int(r.get("datacenters_cnt")) cnt = int(r.get("datacenters_cnt"))
rr = int(r.get("roundrobin_cnt")) rr = int(r.get("roundrobin_cnt"))
if rr + 1 >= cnt: if rr + 1 >= cnt:
r.set("roundrobin_cnt", 1) r.set("roundrobin_cnt", 1)
else: else:
r.incr("roundrobin_cnt") r.incr("roundrobin_cnt")
return str(r.hget("datacenters_hash", rr)) return str(r.hget("datacenters_hash", rr))
def set_token(username: str, datacenter: str, token: str): def set_token(username: str, datacenter: str, token: str):
print(f"tokens:{username}" + datacenter) print(f"tokens:{username}" + datacenter)
r.hset(f"tokens:{username}", datacenter, token) r.hset(f"tokens:{username}", datacenter, token)
def get_datacenter_token(username: str, datacenter: str): def get_datacenter_token(username: str, datacenter: str):
return str(r.hget(f"tokens:{username}", datacenter)) return str(r.hget(f"tokens:{username}", datacenter))
def check_user(data: UserLoginSchema): def check_user(data: UserLoginSchema):
user = PUser.find(PUser.username == data.username).all() user = PUser.find(PUser.username == data.username).all()
if pbkdf2_sha256.verify(data.password, user[0].password): if pbkdf2_sha256.verify(data.password, user[0].password):
return user[0] return user[0]
return False return False
def create_puser(user: User): def create_puser(user: User):
s = PUser.find(PUser.username == user.username).all() s = PUser.find(PUser.username == user.username).all()
if s: if s:
raise HTTPException(status_code=403, detail="User already exists") raise HTTPException(status_code=403, detail="User already exists")
user = PUser( user = PUser(
username=user.username, username=user.username, email=user.email, password=hash_pass(user.password)
email=user.email,
password=hash_pass(user.password)
) )
user.save() user.save()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment