Add db class and begin using it
This commit is contained in:
parent
4ae27db51e
commit
ce43f210ba
6 changed files with 148 additions and 18 deletions
7
app.py
7
app.py
|
@ -49,11 +49,14 @@ def index():
|
|||
user = latosa.login_user(username, password)
|
||||
if user:
|
||||
login_user(user)
|
||||
session = unquote(b64decode(form.session.data))
|
||||
client_id = form.client_id.data
|
||||
session = form.session.data
|
||||
if session:
|
||||
session = unquote(b64decode(session))
|
||||
if session and client_id:
|
||||
session_auth = Auth()
|
||||
session_data = session_auth.decrypt(client_id, session)
|
||||
session_key = session_data['session_key']
|
||||
if 'redirect_to' in session_data:
|
||||
response_data = {
|
||||
'status': 'success',
|
||||
|
@ -63,7 +66,7 @@ def index():
|
|||
'groups': user.groups,
|
||||
'timestamp': time.time()
|
||||
}
|
||||
cyphertext = session_auth.encrypt(client_id, response_data)
|
||||
cyphertext = session_auth.encrypt(session_key, response_data)
|
||||
b64 = quote(b64encode(cyphertext).decode('utf-8'))
|
||||
url = session_data['redirect_to']
|
||||
redirect_to = f'{url}?session={b64}'
|
||||
|
|
21
auth.py
21
auth.py
|
@ -2,16 +2,29 @@ import json
|
|||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from db import DB
|
||||
|
||||
|
||||
class Auth:
|
||||
|
||||
def __init__(self):
|
||||
self.db = DB()
|
||||
|
||||
def get_key(self, client_id):
|
||||
return Fernet.generate_key()
|
||||
bind_params = {'client_id': client_id}
|
||||
return self.db.execute(
|
||||
'SELECT client_secret FROM auth WHERE client_id = :client_id',
|
||||
bind_params)[0]
|
||||
|
||||
def set_key(self, client_id, client_secret):
|
||||
bind_params = {'client_id': client_id, 'client_secret': client_secret}
|
||||
self.db.execute(
|
||||
'INSERT INTO auth (client_id, client_secret) VALUES (:client_id, :client_secret)',
|
||||
bind_params)
|
||||
|
||||
def decrypt(self, cyphertext, client_id) -> dict:
|
||||
return json.loads(
|
||||
Fernet(self.get_key(client_id)).decrypt(cyphertext).decode())
|
||||
|
||||
def encrypt(self, data, client_id) -> bytes:
|
||||
return Fernet(self.get_key(client_id)).encrypt(
|
||||
json.dumps(data).encode('utf-8'))
|
||||
def encrypt(self, session_key, data) -> bytes:
|
||||
return Fernet(session_key).encrypt(json.dumps(data).encode('utf-8'))
|
||||
|
|
60
db.py
Normal file
60
db.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from os import mkdir
|
||||
from os.path import isdir
|
||||
|
||||
from sqlalchemy import (BigInteger, Boolean, Column, Integer, MetaData, String,
|
||||
Table, create_engine, inspect)
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy_utils import create_database, database_exists
|
||||
|
||||
|
||||
class DB:
|
||||
|
||||
def __init__(self, db_url=None):
|
||||
if db_url is None:
|
||||
db_path = '/app/db'
|
||||
if not isdir(db_path):
|
||||
mkdir(db_path)
|
||||
db_url = f'sqlite:///{db_path}/latosa.db'
|
||||
db_engine = db_url.split('://')[0]
|
||||
if not database_exists(db_url):
|
||||
create_database(db_url)
|
||||
self.engine = create_engine(db_url, echo=True)
|
||||
self.inspector: Inspector = inspect(self.engine)
|
||||
self.tables = {}
|
||||
tables = [
|
||||
Table(
|
||||
'auth', MetaData(),
|
||||
Column('id',
|
||||
BigInteger().with_variant(Integer, db_engine),
|
||||
primary_key=True), Column('client_id', String),
|
||||
Column('client_secret', String)),
|
||||
Table(
|
||||
'users', MetaData(),
|
||||
Column('id',
|
||||
BigInteger().with_variant(Integer, db_engine),
|
||||
primary_key=True), Column('uid', String),
|
||||
Column('email', String), Column('groups', String),
|
||||
Column('password_hash', String), Column('salt', String),
|
||||
Column('is_active', Boolean), Column('is_admin', Boolean),
|
||||
Column('is_anonymous', Boolean),
|
||||
Column('display_name', String))
|
||||
]
|
||||
if not database_exists(db_url):
|
||||
create_database(db_url)
|
||||
for table in tables:
|
||||
self.tables[table.name] = table
|
||||
if not self.inspector.has_table(table.name):
|
||||
table.metadata.create_all(self.engine)
|
||||
|
||||
def connect(self):
|
||||
return self.engine.connect()
|
||||
|
||||
def execute(self, s: str, bind_params: dict):
|
||||
conn = self.connect()
|
||||
res = conn.execute(text(s), bind_params)
|
||||
conn.commit()
|
||||
return res
|
||||
|
||||
def get_table(self, table_name):
|
||||
return self.tables[table_name]
|
|
@ -1,4 +1,5 @@
|
|||
import glob
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
@ -11,7 +12,8 @@ class LaToSa:
|
|||
|
||||
def __init__(self, app: Flask):
|
||||
self.users = [
|
||||
User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!')
|
||||
User(app, 'micke', 'Micke Nordin', 'hej@mic.ke', 'S3cr3t!',
|
||||
secrets.token_hex(32))
|
||||
]
|
||||
|
||||
def get_users(self):
|
||||
|
|
|
@ -4,3 +4,5 @@ flask-login==0.6.3
|
|||
flask-wtf==1.2.1
|
||||
flask==3.0.2
|
||||
pyyaml==6.0.1
|
||||
SQLAlchemy==2.0.27
|
||||
SQLAlchemy-Utils==0.41.1
|
||||
|
|
72
user.py
72
user.py
|
@ -1,7 +1,23 @@
|
|||
import os
|
||||
|
||||
from flask import Flask
|
||||
from flask_bcrypt import Bcrypt
|
||||
|
||||
from db import DB
|
||||
|
||||
|
||||
class User:
|
||||
def __init__(self, app: Flask, uid: str, display_name: str, email:str, password: str, admin: bool = False, groups: list[str] = []):
|
||||
|
||||
def __init__(self,
|
||||
app: Flask,
|
||||
uid: str,
|
||||
display_name: str,
|
||||
email: str,
|
||||
password: str,
|
||||
salt: str,
|
||||
admin: bool = False,
|
||||
groups: list[str] = []):
|
||||
self.db = DB()
|
||||
self.display_name = display_name
|
||||
self.email = email
|
||||
self.groups = groups
|
||||
|
@ -11,11 +27,14 @@ class User:
|
|||
self.is_authenticated = False
|
||||
self.uid = uid
|
||||
self.bcrypt = Bcrypt(app)
|
||||
self.salt = self.get_salt()
|
||||
self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8')
|
||||
self.salt = salt
|
||||
self.password_hash = self.bcrypt.generate_password_hash(
|
||||
password + self.salt).decode('utf-8')
|
||||
self.commit()
|
||||
|
||||
def check_password(self, password: str):
|
||||
return self.bcrypt.check_password_hash(self.password_hash, password + self.salt)
|
||||
return self.bcrypt.check_password_hash(self.password_hash,
|
||||
password + self.salt)
|
||||
|
||||
def get_id(self):
|
||||
return self.uid
|
||||
|
@ -26,24 +45,55 @@ class User:
|
|||
def get_email(self):
|
||||
return self.email
|
||||
|
||||
def get_salt(self):
|
||||
return "salt"
|
||||
def get_groups(self):
|
||||
return ','.join(self.groups)
|
||||
|
||||
def set_active(self, active: bool):
|
||||
self.is_active = active
|
||||
self.commit()
|
||||
|
||||
def set_authenticated(self, authenticated: bool):
|
||||
self.is_authenticated = authenticated
|
||||
|
||||
self.commit()
|
||||
|
||||
def set_anonymous(self, anonymous: bool):
|
||||
self.is_anonymous = anonymous
|
||||
|
||||
self.commit()
|
||||
|
||||
def set_admin(self, admin: bool):
|
||||
self.is_admin = admin
|
||||
self.commit()
|
||||
|
||||
def set_email(self, email: str):
|
||||
self.email = email
|
||||
|
||||
def set_password(self, password: str):
|
||||
self.password_hash = self.bcrypt.generate_password_hash(password + self.salt).decode('utf-8')
|
||||
self.commit()
|
||||
|
||||
def set_password(self, password: str):
|
||||
self.password_hash = self.bcrypt.generate_password_hash(
|
||||
password + self.salt).decode('utf-8')
|
||||
self.commit()
|
||||
|
||||
def commit(self):
|
||||
bind_params = {
|
||||
'uid': self.uid,
|
||||
'display_name': self.display_name,
|
||||
'is_active': self.is_active,
|
||||
'is_anonymous': self.is_anonymous,
|
||||
'is_admin': self.is_admin,
|
||||
'email': self.email,
|
||||
'password_hash': self.password_hash,
|
||||
'salt': self.salt,
|
||||
'groups': self.get_groups(),
|
||||
}
|
||||
statement = "INSERT OR REPLACE INTO users (uid, display_name, is_active, is_anonymous, is_admin, email, password_hash, salt, groups) VALUES(:uid, :display_name, :is_active, :is_anonymous, :is_admin, :email, :password_hash, :salt, :groups)"
|
||||
self.db.execute(statement, bind_params)
|
||||
|
||||
@staticmethod
|
||||
def users_from_db() -> list:
|
||||
db = DB()
|
||||
statement = "SELECT * FROM users"
|
||||
result = db.execute(statement, {})
|
||||
users = []
|
||||
for row in result:
|
||||
users.append(User(**row))
|
||||
return users
|
||||
|
|
Loading…
Add table
Reference in a new issue