66 lines
1.5 KiB
Python
66 lines
1.5 KiB
Python
# Local
|
|
from all_paw_care.db.types.user import User
|
|
from all_paw_care.db.types.base import Base
|
|
from all_paw_care.db import DBEngine
|
|
from all_paw_care.db import DBSession
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import select
|
|
|
|
def ensure_tables():
    """Create the tables registered on ``Base.metadata`` using ``DBEngine``.

    ``MetaData.create_all`` checks for existing tables by default, so this
    only issues CREATE statements for tables that are not yet present.
    """
    metadata = Base.metadata
    metadata.create_all(DBEngine)
|
|
|
|
def login(username: str):
    """Report whether ``username`` resolves to a user via ``get_user``.

    Args:
        username: Username to look up.

    Returns:
        True if ``get_user`` returned a truthy result, otherwise False.
    """
    return bool(get_user(username))
|
|
|
|
def add_user(username: str):
    """Insert a new ``User`` with the given username and commit it.

    The ``session.begin()`` context manager commits the transaction when
    the block exits normally and rolls it back if an exception escapes, so
    no manual commit/rollback is needed. (The previous version's bare
    ``except`` plus commit-in-``finally`` both swallowed errors and
    conflicted with that context manager.)

    Args:
        username: Value stored in the new user's ``username`` column.

    Returns:
        True once the insert has been committed.

    Raises:
        Any exception raised while adding or committing the row (for
        example a database constraint violation) propagates to the caller
        after the transaction has been rolled back.
    """
    with DBSession() as session, session.begin():
        session.add(User(username=username))
    return True
|
|
|
|
def get_users():
    """Return every user as an ``(id, username)`` tuple, ordered by id.

    Returns:
        A list of ``(id, username)`` tuples, one per ``User`` row.
    """
    query = select(User).order_by(User.id)
    with DBSession() as session, session.begin():
        # Read the ORM attributes while the session is still open.
        return [(row.id, row.username) for row in session.scalars(query).all()]
|
|
|
|
def get_user(username: str = None):
    """Look up a single user by exact username.

    Args:
        username: Username to match. A falsy value (None or "") skips the
            query entirely.

    Returns:
        An ``(id, username)`` tuple when exactly one row matches; otherwise
        None (no username given, no match, or more than one match).
    """
    if not username:
        return None

    statement = select(User).where(User.username == username)
    with DBSession() as session, session.begin():
        matches = session.scalars(statement).all()
        # Ambiguous (duplicate) or missing usernames are both reported as None.
        if len(matches) != 1:
            return None
        (found,) = matches
        return (found.id, found.username)
|
|
|
|
|
|
|
|
def user_exists(username: str) -> bool:
    """Return whether at least one ``User`` row has the given username.

    The previous version printed the raw ``Result`` object and always
    returned None. This one fetches at most one matching primary key and
    reports whether anything was found.

    Args:
        username: Username to match exactly.

    Returns:
        True if a matching row exists, otherwise False.
    """
    with DBSession() as session, session.begin():
        # LIMIT 1: only existence matters, so avoid loading duplicate rows.
        first_id = session.scalars(
            select(User.id).where(User.username == username).limit(1)
        ).first()
    return first_id is not None
|
|
|