Add db class and begin using it
This commit is contained in:
parent
4ae27db51e
commit
ce43f210ba
6 changed files with 148 additions and 18 deletions
7
app.py
7
app.py
|
@ -49,11 +49,14 @@ def index():
|
||||||
user = latosa.login_user(username, password)
|
user = latosa.login_user(username, password)
|
||||||
if user:
|
if user:
|
||||||
login_user(user)
|
login_user(user)
|
||||||
session = unquote(b64decode(form.session.data))
|
|
||||||
client_id = form.client_id.data
|
client_id = form.client_id.data
|
||||||
|
session = form.session.data
|
||||||
|
if session:
|
||||||
|
session = unquote(b64decode(session))
|
||||||
if session and client_id:
|
if session and client_id:
|
||||||
session_auth = Auth()
|
session_auth = Auth()
|
||||||
session_data = session_auth.decrypt(client_id, session)
|
session_data = session_auth.decrypt(client_id, session)
|
||||||
|
session_key = session_data['session_key']
|
||||||
if 'redirect_to' in session_data:
|
if 'redirect_to' in session_data:
|
||||||
response_data = {
|
response_data = {
|
||||||
'status': 'success',
|
'status': 'success',
|
||||||
|
@ -63,7 +66,7 @@ def index():
|
||||||
'groups': user.groups,
|
'groups': user.groups,
|
||||||
'timestamp': time.time()
|
'timestamp': time.time()
|
||||||
}
|
}
|
||||||
cyphertext = session_auth.encrypt(client_id, response_data)
|
cyphertext = session_auth.encrypt(session_key, response_data)
|
||||||
b64 = quote(b64encode(cyphertext).decode('utf-8'))
|
b64 = quote(b64encode(cyphertext).decode('utf-8'))
|
||||||
url = session_data['redirect_to']
|
url = session_data['redirect_to']
|
||||||
redirect_to = f'{url}?session={b64}'
|
redirect_to = f'{url}?session={b64}'
|
||||||
|
|
21
auth.py
21
auth.py
|
@ -2,16 +2,29 @@ import json
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from db import DB
|
||||||
|
|
||||||
|
|
||||||
class Auth:
|
class Auth:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.db = DB()
|
||||||
|
|
||||||
def get_key(self, client_id):
|
def get_key(self, client_id):
|
||||||
return Fernet.generate_key()
|
bind_params = {'client_id': client_id}
|
||||||
|
return self.db.execute(
|
||||||
|
'SELECT client_secret FROM auth WHERE client_id = :client_id',
|
||||||
|
bind_params)[0]
|
||||||
|
|
||||||
|
def set_key(self, client_id, client_secret):
|
||||||
|
bind_params = {'client_id': client_id, 'client_secret': client_secret}
|
||||||
|
self.db.execute(
|
||||||
|
'INSERT INTO auth (client_id, client_secret) VALUES (:client_id, :client_secret)',
|
||||||
|
bind_params)
|
||||||
|
|
||||||
def decrypt(self, cyphertext, client_id) -> dict:
|
def decrypt(self, cyphertext, client_id) -> dict:
|
||||||
return json.loads(
|
return json.loads(
|
||||||
Fernet(self.get_key(client_id)).decrypt(cyphertext).decode())
|
Fernet(self.get_key(client_id)).decrypt(cyphertext).decode())
|
||||||
|
|
||||||
def encrypt(self, data, client_id) -> bytes:
|
def encrypt(self, session_key, data) -> bytes:
|
||||||
return Fernet(self.get_key(client_id)).encrypt(
|
return Fernet(session_key).encrypt(json.dumps(data).encode('utf-8'))
|
||||||
json.dumps(data).encode('utf-8'))
|
|
||||||
|
|
60
db.py
Normal file
60
db.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
from os import mkdir
|
||||||
|
from os.path import isdir
|
||||||
|
|
||||||
|
from sqlalchemy import (BigInteger, Boolean, Column, Integer, MetaData, String,
|
||||||
|
Table, create_engine, inspect)
|
||||||
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
|
from sqlalchemy.sql import text
|
||||||
|
from sqlalchemy_utils import create_database, database_exists
|
||||||
|
|
||||||
|
|
||||||
|
class DB:
|
||||||
|
|
||||||
|
def __init__(self, db_url=None):
|
||||||
|
if db_url is None:
|
||||||
|
db_path = '/app/db'
|
||||||
|
if not isdir(db_path):
|
||||||
|
mkdir(db_path)
|
||||||
|
db_url = f'sqlite:///{db_path}/latosa.db'
|
||||||
|
db_engine = db_url.split('://')[0]
|
||||||
|
if not database_exists(db_url):
|
||||||
|
create_database(db_url)
|
||||||
|
self.engine = create_engine(db_url, echo=True)
|
||||||
|
self.inspector: Inspector = inspect(self.engine)
|
||||||
|
self.tables = {}
|
||||||
|
tables = [
|
||||||
|
Table(
|
||||||
|
'auth', MetaData(),
|
||||||
|
Column('id',
|
||||||
|
BigInteger().with_variant(Integer, db_engine),
|
||||||
|
primary_key=True), Column('client_id', String),
|
||||||
|
Column('client_secret', String)),
|
||||||
|
Table(
|
||||||
|
'users', MetaData(),
|
||||||
|
Column('id',
|
||||||
|
BigInteger().with_variant(Integer, db_engine),
|
||||||
|
primary_key=True), Column('uid', String),
|
||||||
|
Column('email', String), Column('groups', String),
|
||||||
|
Column('password_hash', String), Column('salt', String),
|
||||||
|
Column('is_active', Boolean), Column('is_admin', Boolean),
|
||||||
|
Column('is_anonymous', Boolean),
|
||||||
|
Column('display_name', String))
|
||||||
|
]
|
||||||
|
if not database_exists(db_url):
|
||||||
|
create_database(db_url)
|
||||||
|
for table in tables:
|
||||||
|
self.tables[table.name] = table
|
||||||
|
if not self.inspector.has_table(table.name):
|
||||||
|
table.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
return self.engine.connect()
|
||||||
|
|
||||||
|
def execute(self, s: str, bind_params: dict):
|
||||||
|
conn = self.connect()
|
||||||
|
res = conn.execute(text(s), bind_params)
|
||||||
|
conn.commit()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_table(self, table_name):
|
||||||
|
return self.tables[table_name]
|
|
@ -1,4 +1,5 @@
|
||||||
import glob
|
import glob
|
||||||
|
import secrets
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -11,7 +12,8 @@ class LaToSa:
|
||||||
|
|
||||||
def __init__(self, app: Flask):
|
def __init__(self, app: Flask):
|
||||||
self.users = [
|
self.users = [
|
||||||
User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!')
|
User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!',
|
||||||
|
secrets.token_hex(32))
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_users(self):
|
def get_users(self):
|
||||||
|
|
|
@ -4,3 +4,5 @@ flask-login==0.6.3
|
||||||
flask-wtf==1.2.1
|
flask-wtf==1.2.1
|
||||||
flask==3.0.2
|
flask==3.0.2
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
|
SQLAlchemy==2.0.27
|
||||||
|
SQLAlchemy-Utils==0.41.1
|
||||||
|
|
72
user.py
72
user.py
|
@ -1,7 +1,23 @@
|
||||||
|
import os
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask_bcrypt import Bcrypt
|
from flask_bcrypt import Bcrypt
|
||||||
|
|
||||||
|
from db import DB
|
||||||
|
|
||||||
|
|
||||||
class User:
|
class User:
|
||||||
def __init__(self, app: Flask, uid: str, display_name: str, email:str, password: str, admin: bool = False, groups: list[str] = []):
|
|
||||||
|
def __init__(self,
|
||||||
|
app: Flask,
|
||||||
|
uid: str,
|
||||||
|
display_name: str,
|
||||||
|
email: str,
|
||||||
|
password: str,
|
||||||
|
salt: str,
|
||||||
|
admin: bool = False,
|
||||||
|
groups: list[str] = []):
|
||||||
|
self.db = DB()
|
||||||
self.display_name = display_name
|
self.display_name = display_name
|
||||||
self.email = email
|
self.email = email
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
|
@ -11,11 +27,14 @@ class User:
|
||||||
self.is_authenticated = False
|
self.is_authenticated = False
|
||||||
self.uid = uid
|
self.uid = uid
|
||||||
self.bcrypt = Bcrypt(app)
|
self.bcrypt = Bcrypt(app)
|
||||||
self.salt = self.get_salt()
|
self.salt = salt
|
||||||
self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8')
|
self.password_hash = self.bcrypt.generate_password_hash(
|
||||||
|
password + self.salt).decode('utf-8')
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def check_password(self, password: str):
|
def check_password(self, password: str):
|
||||||
return self.bcrypt.check_password_hash(self.password_hash, password + self.salt)
|
return self.bcrypt.check_password_hash(self.password_hash,
|
||||||
|
password + self.salt)
|
||||||
|
|
||||||
def get_id(self):
|
def get_id(self):
|
||||||
return self.uid
|
return self.uid
|
||||||
|
@ -26,24 +45,55 @@ class User:
|
||||||
def get_email(self):
|
def get_email(self):
|
||||||
return self.email
|
return self.email
|
||||||
|
|
||||||
def get_salt(self):
|
def get_groups(self):
|
||||||
return "salt"
|
return ','.join(self.groups)
|
||||||
|
|
||||||
def set_active(self, active: bool):
|
def set_active(self, active: bool):
|
||||||
self.is_active = active
|
self.is_active = active
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def set_authenticated(self, authenticated: bool):
|
def set_authenticated(self, authenticated: bool):
|
||||||
self.is_authenticated = authenticated
|
self.is_authenticated = authenticated
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def set_anonymous(self, anonymous: bool):
|
def set_anonymous(self, anonymous: bool):
|
||||||
self.is_anonymous = anonymous
|
self.is_anonymous = anonymous
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def set_admin(self, admin: bool):
|
def set_admin(self, admin: bool):
|
||||||
self.is_admin = admin
|
self.is_admin = admin
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def set_email(self, email: str):
|
def set_email(self, email: str):
|
||||||
self.email = email
|
self.email = email
|
||||||
|
self.commit()
|
||||||
def set_password(self, password: str):
|
|
||||||
self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8')
|
|
||||||
|
|
||||||
|
def set_password(self, password: str):
|
||||||
|
self.password_hash = self.bcrypt.generate_password_hash(
|
||||||
|
password + self.salt).decode('utf-8')
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
bind_params = {
|
||||||
|
'uid': self.uid,
|
||||||
|
'display_name': self.display_name,
|
||||||
|
'is_active': self.is_active,
|
||||||
|
'is_anonymous': self.is_anonymous,
|
||||||
|
'is_admin': self.is_admin,
|
||||||
|
'email': self.email,
|
||||||
|
'password_hash': self.password_hash,
|
||||||
|
'salt': self.salt,
|
||||||
|
'groups': self.get_groups(),
|
||||||
|
}
|
||||||
|
statement = "INSERT OR REPLACE INTO users (uid, display_name, is_active, is_anonymous, is_admin, email, password_hash, salt, groups) VALUES(:uid, :display_name, :is_active, :is_anonymous, :is_admin, :email, :password_hash, :salt, :groups)"
|
||||||
|
self.db.execute(statement, bind_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def users_from_db() -> list:
|
||||||
|
db = DB()
|
||||||
|
statement = "SELECT * FROM users"
|
||||||
|
result = db.execute(statement, {})
|
||||||
|
users = []
|
||||||
|
for row in result:
|
||||||
|
users.append(User(**row))
|
||||||
|
return users
|
||||||
|
|
Loading…
Add table
Reference in a new issue