From ce43f210ba2d696c7b59005df746391a999d5e4a Mon Sep 17 00:00:00 2001 From: Micke Nordin Date: Sun, 3 Mar 2024 11:28:05 +0100 Subject: [PATCH] Add db class and begin using it --- app.py | 7 +++-- auth.py | 21 ++++++++++++--- db.py | 60 +++++++++++++++++++++++++++++++++++++++++ latosa.py | 4 ++- requirements.txt | 2 ++ user.py | 70 +++++++++++++++++++++++++++++++++++++++++------- 6 files changed, 147 insertions(+), 17 deletions(-) create mode 100644 db.py diff --git a/app.py b/app.py index 29ca959..a7d45f3 100644 --- a/app.py +++ b/app.py @@ -49,11 +49,14 @@ def index(): user = latosa.login_user(username, password) if user: login_user(user) - session = unquote(b64decode(form.session.data)) client_id = form.client_id.data + session = form.session.data + if session: + session = unquote(b64decode(session)) if session and client_id: session_auth = Auth() session_data = session_auth.decrypt(client_id, session) + session_key = session_data['session_key'] if 'redirect_to' in session_data: response_data = { 'status': 'success', @@ -63,7 +66,7 @@ def index(): 'groups': user.groups, 'timestamp': time.time() } - cyphertext = session_auth.encrypt(client_id, response_data) + cyphertext = session_auth.encrypt(session_key, response_data) b64 = quote(b64encode(cyphertext).decode('utf-8')) url = session_data['redirect_to'] redirect_to = f'{url}?session={b64}' diff --git a/auth.py b/auth.py index edd9459..8923329 100644 --- a/auth.py +++ b/auth.py @@ -2,16 +2,29 @@ import json from cryptography.fernet import Fernet +from db import DB + class Auth: + def __init__(self): + self.db = DB() + def get_key(self, client_id): - return Fernet.generate_key() + bind_params = {'client_id': client_id} + return self.db.execute( + 'SELECT client_secret FROM auth WHERE client_id = :client_id', + bind_params)[0] + + def set_key(self, client_id, client_secret): + bind_params = {'client_id': client_id, 'client_secret': client_secret} + self.db.execute( + 'INSERT INTO auth (client_id, client_secret) VALUES (:client_id, :client_secret)', + bind_params) def decrypt(self, cyphertext, client_id) -> dict: return json.loads( Fernet(self.get_key(client_id)).decrypt(cyphertext).decode()) - def encrypt(self, data, client_id) -> bytes: - return Fernet(self.get_key(client_id)).encrypt( - json.dumps(data).encode('utf-8')) + def encrypt(self, session_key, data) -> bytes: + return Fernet(session_key).encrypt(json.dumps(data).encode('utf-8')) diff --git a/db.py b/db.py new file mode 100644 index 0000000..0fcaefa --- /dev/null +++ b/db.py @@ -0,0 +1,60 @@ +from os import mkdir +from os.path import isdir + +from sqlalchemy import (BigInteger, Boolean, Column, Integer, MetaData, String, + Table, create_engine, inspect) +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.sql import text +from sqlalchemy_utils import create_database, database_exists + + +class DB: + + def __init__(self, db_url=None): + if db_url is None: + db_path = '/app/db' + if not isdir(db_path): + mkdir(db_path) + db_url = f'sqlite:///{db_path}/latosa.db' + db_engine = db_url.split('://')[0] + if not database_exists(db_url): + create_database(db_url) + self.engine = create_engine(db_url, echo=True) + self.inspector: Inspector = inspect(self.engine) + self.tables = {} + tables = [ + Table( + 'auth', MetaData(), + Column('id', + BigInteger().with_variant(Integer, db_engine), + primary_key=True), Column('client_id', String), + Column('client_secret', String)), + Table( + 'users', MetaData(), + Column('id', + BigInteger().with_variant(Integer, db_engine), + primary_key=True), Column('uid', String), + Column('email', String), Column('groups', String), + Column('password_hash', String), Column('salt', String), + Column('is_active', Boolean), Column('is_admin', Boolean), + Column('is_anonymous', Boolean), + Column('display_name', String)) + ] + if not database_exists(db_url): + create_database(db_url) + for table in tables: + self.tables[table.name] = table + if not self.inspector.has_table(table.name): + table.metadata.create_all(self.engine) + + def connect(self): + return self.engine.connect() + + def execute(self, s: str, bind_params: dict): + conn = self.connect() + res = conn.execute(text(s), bind_params) + conn.commit() + return res + + def get_table(self, table_name): + return self.tables[table_name] diff --git a/latosa.py b/latosa.py index b74d4ae..5482e23 100644 --- a/latosa.py +++ b/latosa.py @@ -1,4 +1,5 @@ import glob +import secrets import sys import yaml @@ -11,7 +12,8 @@ class LaToSa: def __init__(self, app: Flask): self.users = [ - User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!') + User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!', + secrets.token_hex(32)) ] def get_users(self): diff --git a/requirements.txt b/requirements.txt index d45f183..476adc0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ flask-login==0.6.3 flask-wtf==1.2.1 flask==3.0.2 pyyaml==6.0.1 +SQLAlchemy==2.0.27 +SQLAlchemy-Utils==0.41.1 diff --git a/user.py b/user.py index b891f80..942776a 100644 --- a/user.py +++ b/user.py @@ -1,7 +1,23 @@ +import os + from flask import Flask from flask_bcrypt import Bcrypt + +from db import DB + + class User: - def __init__(self, app: Flask, uid: str, display_name: str, email:str, password: str, admin: bool = False, groups: list[str] = []): + + def __init__(self, + app: Flask, + uid: str, + display_name: str, + email: str, + password: str, + salt: str, + admin: bool = False, + groups: list[str] = []): + self.db = DB() self.display_name = display_name self.email = email self.groups = groups @@ -11,11 +27,14 @@ class User: self.is_authenticated = False self.uid = uid self.bcrypt = Bcrypt(app) - self.salt = self.get_salt() - self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8') + self.salt = salt + self.password_hash = self.bcrypt.generate_password_hash( + password + self.salt).decode('utf-8') + self.commit() def check_password(self, password: str): - return self.bcrypt.check_password_hash(self.password_hash, password + self.salt) + return self.bcrypt.check_password_hash(self.password_hash, + password + self.salt) def get_id(self): return self.uid @@ -26,24 +45,55 @@ class User: def get_email(self): return self.email - def get_salt(self): - return "salt" + def get_groups(self): + return ','.join(self.groups) def set_active(self, active: bool): self.is_active = active + self.commit() def set_authenticated(self, authenticated: bool): self.is_authenticated = authenticated - + self.commit() + def set_anonymous(self, anonymous: bool): self.is_anonymous = anonymous - + self.commit() + def set_admin(self, admin: bool): self.is_admin = admin + self.commit() def set_email(self, email: str): self.email = email - + self.commit() + def set_password(self, password: str): - self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8') + self.password_hash = self.bcrypt.generate_password_hash( + password + self.salt).decode('utf-8') + self.commit() + + def commit(self): + bind_params = { + 'uid': self.uid, + 'display_name': self.display_name, + 'is_active': self.is_active, + 'is_anonymous': self.is_anonymous, + 'is_admin': self.is_admin, + 'email': self.email, + 'password_hash': self.password_hash, + 'salt': self.salt, + 'groups': self.get_groups(), + } + statement = "INSERT OR REPLACE INTO users (uid, display_name, is_active, is_anonymous, is_admin, email, password_hash, salt, groups) VALUES(:uid, :display_name, :is_active, :is_anonymous, :is_admin, :email, :password_hash, :salt, :groups)" + self.db.execute(statement, bind_params) + @staticmethod + def users_from_db() -> list: + db = DB() + statement = "SELECT * FROM users" + result = db.execute(statement, {}) + users = [] + for row in result: + users.append(User(**row)) + return users