Commit e2461572 by Karsa Zoltán István

refactoring

parent adb1f3cd
from fastapi.responses import ORJSONResponse
from sredis.sredis import get_datacenter_token, rr_get
from fastapi import HTTPException
import requests
import json
def proxy_datacenters(serverpath: str, username, method="GET", balancer_fun = rr_get):
server = balancer_fun()
token = get_datacenter_token(username, server)
url=f"{server}/{serverpath}"
t_resp = requests.request(
method=method,
url=url,
allow_redirects=False, verify=False,
headers={
'Authorization': token
}
)
if t_resp.status_code / 100 != 2:
raise HTTPException(status_code=t_resp.status_code, detail="Remote server error")
response = ORJSONResponse(
json.loads(t_resp.content),
status_code=t_resp.status_code
)
return response
\ No newline at end of file
...@@ -2,10 +2,10 @@ from fastapi import Request, HTTPException, Depends ...@@ -2,10 +2,10 @@ from fastapi import Request, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from .auth import decodeJWT from .auth import decodeJWT
from sredis.models import PUser
class JWTBearer(HTTPBearer): class JWTBearer(HTTPBearer):
username: str = None
def __init__(self, auto_error: bool = True): def __init__(self, auto_error: bool = True):
super(JWTBearer, self).__init__(auto_error=auto_error) super(JWTBearer, self).__init__(auto_error=auto_error)
...@@ -18,7 +18,6 @@ class JWTBearer(HTTPBearer): ...@@ -18,7 +18,6 @@ class JWTBearer(HTTPBearer):
cred = self.verify_jwt(credentials.credentials) cred = self.verify_jwt(credentials.credentials)
if not cred: if not cred:
raise HTTPException(status_code=403, detail="Invalid token or expired token.") raise HTTPException(status_code=403, detail="Invalid token or expired token.")
self.username = cred['user_id']
return credentials.credentials return credentials.credentials
else: else:
raise HTTPException(status_code=403, detail="Invalid authorization code.") raise HTTPException(status_code=403, detail="Invalid authorization code.")
...@@ -37,3 +36,10 @@ class JWTBearer(HTTPBearer): ...@@ -37,3 +36,10 @@ class JWTBearer(HTTPBearer):
async def get_current_user(token: str = Depends(JWTBearer())) -> str: async def get_current_user(token: str = Depends(JWTBearer())) -> str:
payload = decodeJWT(token) payload = decodeJWT(token)
return payload['user_id'] return payload['user_id']
async def admin_user(token: str = Depends(JWTBearer())) -> str:
payload = decodeJWT(token)
user = PUser.find(PUser.username == payload['user_id']).all()[0]
if not user.admin:
raise HTTPException(status_code=401, detail="you can not have access")
return payload['user_id']
\ No newline at end of file
from fastapi import FastAPI, Response, Body, Depends from fastapi import FastAPI, Response, Body, Depends
from fastapi.responses import ORJSONResponse from typing import List
import json from balancer.util import proxy_datacenters
from fastapi import HTTPException
from typing import Union, List
import requests
from sredis.sredis import *
from sredis.models import * from sredis.models import *
from sredis.sredis import check_user, create_puser, add_datacenter, set_token
import logging import logging
from core.models import User, DataCenter, Token import requests
from core.models import User, DataCenter, Token, UserLoginSchema
from core.auth import signJWT from core.auth import signJWT
from core.bearer import get_current_user from core.bearer import get_current_user, admin_user
from redis_om import Migrator
logging.config.fileConfig('logging.conf', disable_existing_loggers=False) logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
# get root logger # get root logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
Migrator().run()
app = FastAPI() app = FastAPI()
add_datacenter("https://kappa1.fured.cloud.bme.hu") add_datacenter("https://kappa1.fured.cloud.bme.hu")
...@@ -36,45 +36,24 @@ async def user_login(user: UserLoginSchema = Body(...)): ...@@ -36,45 +36,24 @@ async def user_login(user: UserLoginSchema = Body(...)):
"error": "Wrong login details!" "error": "Wrong login details!"
} }
def _proxy_datacenters(serverpath: str, username, method="GET", balancer_fun = rr_get):
server = balancer_fun()
token = get_datacenter_token(username, server)
url=f"{server}/{serverpath}"
logger.debug("Req: " + url)
t_resp = requests.request(
method=method,
url=url,
allow_redirects=False, verify=False,
headers={
'Authorization': token
}
)
if t_resp.status_code / 100 != 2:
raise HTTPException(status_code=t_resp.status_code, detail="Remote server error")
response = ORJSONResponse(
json.loads(t_resp.content),
status_code=t_resp.status_code
)
return response
@app.get("/lb/{server_path:path}") @app.get("/lb/{server_path:path}")
def proxy( def proxy(
server_path: str = "/", server_path: str = "/",
username = Depends(get_current_user) username = Depends(get_current_user)
): ):
return _proxy_datacenters(server_path, username) return proxy_datacenters(server_path, username)
@app.post("/lb/{server_path:path}") @app.post("/lb/{server_path:path}")
def proxy( def proxy(
server_path: str = "/", server_path: str = "/",
username = Depends(get_current_user) username = Depends(get_current_user)
): ):
return _proxy_datacenters(server_path, username, method="POST") return proxy_datacenters(server_path, username, method="POST")
@app.post("/add_datacenter/") @app.post("/add_datacenter/")
def create_datacenter( def create_datacenter(
dc: DataCenter = None, dc: DataCenter = None,
username = Depends(get_current_user) username = Depends(admin_user)
): ):
add_datacenter(dc.name) add_datacenter(dc.name)
return Response(status_code=201) return Response(status_code=201)
......
...@@ -5,3 +5,4 @@ class PUser(HashModel): ...@@ -5,3 +5,4 @@ class PUser(HashModel):
username: str = Field(index=True) username: str = Field(index=True)
email: EmailStr email: EmailStr
password: str password: str
admin: bool = False
...@@ -42,7 +42,7 @@ def check_user(data: UserLoginSchema): ...@@ -42,7 +42,7 @@ def check_user(data: UserLoginSchema):
return False return False
def create_puser(user: User): def create_puser(user: User):
s = PUser.find(PUser.username == 'karsa').all() s = PUser.find(PUser.username == user.username).all()
if s: if s:
raise HTTPException(status_code=403, detail="User already exists") raise HTTPException(status_code=403, detail="User already exists")
user = PUser( user = PUser(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment