From 57191c2c237d18aee5c828fa664d2299377bbb61 Mon Sep 17 00:00:00 2001 From: freamon Date: Sun, 27 Oct 2024 13:36:17 +0000 Subject: [PATCH] API: utilise DB exceptions for return errors --- app/api/alpha/utils/community.py | 4 ---- app/api/alpha/utils/post.py | 5 +---- app/api/alpha/utils/reply.py | 12 +++--------- app/api/alpha/views.py | 22 ++++++---------------- app/models.py | 7 ++----- app/shared/auth.py | 26 ++++++++++---------------- app/shared/community.py | 8 ++------ app/shared/post.py | 22 ++++++---------------- app/shared/reply.py | 22 ++++++---------------- app/utils.py | 28 +++++++++++----------------- 10 files changed, 47 insertions(+), 109 deletions(-) diff --git a/app/api/alpha/utils/community.py b/app/api/alpha/utils/community.py index 02c6dfaa..8222b377 100644 --- a/app/api/alpha/utils/community.py +++ b/app/api/alpha/utils/community.py @@ -17,15 +17,11 @@ def cached_community_list(type, user_id): else: communities = Community.query.filter_by(banned=False) - print(len(communities.all())) - if user_id is not None: blocked_instance_ids = blocked_instances(user_id) if blocked_instance_ids: communities = communities.filter(Community.instance_id.not_in(blocked_instance_ids)) - print(len(communities.all())) - return communities.all() diff --git a/app/api/alpha/utils/post.py b/app/api/alpha/utils/post.py index 05f241ec..0d9fcf32 100644 --- a/app/api/alpha/utils/post.py +++ b/app/api/alpha/utils/post.py @@ -98,10 +98,7 @@ def get_post(auth, data): user_id = authorise_api_user(auth) if auth else None post_json = post_view(post=id, variant=3, user_id=user_id) - if post_json: - return post_json - else: - raise Exception('post_not_found') + return post_json # would be in app/constants.py diff --git a/app/api/alpha/utils/reply.py b/app/api/alpha/utils/reply.py index fab300a0..436b29fb 100644 --- a/app/api/alpha/utils/reply.py +++ b/app/api/alpha/utils/reply.py @@ -131,9 +131,7 @@ def post_reply(auth, data): language_id = 2 # FIXME: use site language input = {'body': body, 'notify_author': True, 'language_id': language_id} - post = Post.query.get(post_id) - if not post: - raise Exception('parent_not_found') + post = Post.query.filter_by(id=post_id).one() user_id, reply = make_reply(input, post, parent_id, SRC_API, auth) @@ -153,12 +151,8 @@ def put_reply(auth, data): language_id = 2 # FIXME: use site language input = {'body': body, 'notify_author': True, 'language_id': language_id} - reply = PostReply.query.get(reply_id) - if not reply: - raise Exception('reply_not_found') - post = Post.query.get(reply.post_id) - if not post: - raise Exception('post_not_found') + reply = PostReply.query.filter_by(id=reply_id).one() + post = Post.query.filter_by(id=reply.post_id).one() user_id, reply = edit_reply(input, reply, post, SRC_API, auth) diff --git a/app/api/alpha/views.py b/app/api/alpha/views.py index 57de04fa..107b15d7 100644 --- a/app/api/alpha/views.py +++ b/app/api/alpha/views.py @@ -12,9 +12,7 @@ from sqlalchemy import text def post_view(post: Post | int, variant, stub=False, user_id=None, my_vote=0): if isinstance(post, int): - post = Post.query.get(post) - if not post or post.deleted: - raise Exception('post_not_found') + post = Post.query.filter_by(id=post, deleted=False).one() # Variant 1 - models/post/post.dart if variant == 1: @@ -133,9 +131,7 @@ def cached_user_view_variant_1(user: User, stub=False): # 'user' param can be anyone (including the logged in user), 'user_id' param belongs to the user making the request def user_view(user: User | int, variant, stub=False, user_id=None): if isinstance(user, int): - user = User.query.get(user) - if not user: - raise Exception('user_not_found') + user = User.query.filter_by(id=user).one() # Variant 1 - models/person/person.dart if variant == 1: @@ -190,12 +186,10 @@ def cached_community_view_variant_1(community: Community, stub=False): def community_view(community: Community | int | str, variant, stub=False, user_id=None): if isinstance(community, int): - community = Community.query.get(community) + community = Community.query.filter_by(id=community).one() elif isinstance(community, str): name, ap_domain = community.split('@') - community = Community.query.filter_by(name=name, ap_domain=ap_domain).first() - if not community: - raise Exception('community_not_found') + community = Community.query.filter_by(name=name, ap_domain=ap_domain).one() # Variant 1 - models/community/community.dart if variant == 1: @@ -269,9 +263,7 @@ def calculate_if_has_children(reply): # result used as True / False def reply_view(reply: PostReply | int, variant, user_id=None, my_vote=0): if isinstance(reply, int): - reply = PostReply.query.get(reply) - if not reply: - raise Exception('reply_not_found') + reply = PostReply.query.filter_by(id=reply).one() # Variant 1 - models/comment/comment.dart if variant == 1: @@ -393,9 +385,7 @@ def search_view(type): def instance_view(instance: Instance | int, variant): if isinstance(instance, int): - instance = Instance.query.get(instance) - if not instance: - raise Exception('instance_not_found') + instance = Instance.query.filter_by(id=instance).one() if variant == 1: include = ['id', 'domain', 'software', 'version'] diff --git a/app/models.py b/app/models.py index 8ca8ab7b..3c0530ad 100644 --- a/app/models.py +++ b/app/models.py @@ -1041,11 +1041,8 @@ class User(UserMixin, db.Model): {'user_id': self.id, 'type': NOTIF_USER}).scalars()) def encode_jwt_token(self): - try: - payload = {'sub': str(self.id), 'iss': current_app.config['SERVER_NAME'], 'iat': int(time())} - return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256') - except Exception as e: - return str(e) + payload = {'sub': str(self.id), 'iss': current_app.config['SERVER_NAME'], 'iat': int(time())} + return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256') # mark a post as 'read' for this user def mark_post_as_read(self, post): diff --git a/app/shared/auth.py b/app/shared/auth.py index b7bb4f90..669c7de3 100644 --- a/app/shared/auth.py +++ b/app/shared/auth.py @@ -21,23 +21,21 @@ def log_user_in(input, src): if src == SRC_WEB: username = input.user_name.data password = input.password.data + user = User.query.filter_by(user_name=username, ap_id=None).first() elif src == SRC_API: required(["username_or_email", "password"], input) string_expected(["username_or_email", "password"], input) username = input['username_or_email'] password = input['password'] + user = User.query.filter_by(user_name=username, ap_id=None, deleted=False).one() else: return None - user = User.query.filter_by(user_name=username, ap_id=None).first() - - if user is None or user.deleted: - if src == SRC_WEB: + if src == SRC_WEB: + if user is None or user.deleted: flash(_('No account exists with that user name.'), 'error') return redirect(url_for('auth.login')) - elif src == SRC_API: - raise Exception('incorrect_login') if not user.check_password(password): if src == SRC_WEB: @@ -96,13 +94,9 @@ def log_user_in(input, src): response.set_cookie('low_bandwidth', '0', expires=datetime(year=2099, month=12, day=30)) return response elif src == SRC_API: - token = user.encode_jwt_token() - if token: - login_json = { - "jwt": token, - "registration_created": user.verified, - "verify_email_sent": True - } - return login_json - else: - raise Exception('could_not_generate_token') + login_json = { + 'jwt': user.encode_jwt_token(), + 'registration_created': user.verified, + 'verify_email_sent': True + } + return login_json diff --git a/app/shared/community.py b/app/shared/community.py index a226ccb3..94aff35e 100644 --- a/app/shared/community.py +++ b/app/shared/community.py @@ -17,9 +17,7 @@ SRC_API = 3 # call from admin.federation not tested def join_community(community_id: int, src, auth=None, user_id=None, main_user_name=True): if src == SRC_API: - community = Community.query.get(community_id) - if not community: - raise Exception('community_not_found') + community = Community.query.filter_by(id=community_id).one() user = authorise_api_user(auth, return_type='model') else: community = Community.query.get_or_404(community_id) @@ -112,9 +110,7 @@ def join_community(community_id: int, src, auth=None, user_id=None, main_user_na # function can be shared between WEB and API (only API calls it for now) def leave_community(community_id: int, src, auth=None): if src == SRC_API: - community = Community.query.get(community_id) - if not community: - raise Exception('community_not_found') + community = Community.query.filter_by(id=community_id).one() user = authorise_api_user(auth, return_type='model') else: community = Community.query.get_or_404(community_id) diff --git a/app/shared/post.py b/app/shared/post.py index 91cd5d3d..c577f4a2 100644 --- a/app/shared/post.py +++ b/app/shared/post.py @@ -20,9 +20,7 @@ SRC_API = 3 def vote_for_post(post_id: int, vote_direction, src, auth=None): if src == SRC_API: - post = Post.query.get(post_id) - if not post: - raise Exception('post_not_found') + post = Post.query.filter_by(id=post_id).one() user = authorise_api_user(auth, return_type='model') else: post = Post.query.get_or_404(post_id) @@ -97,9 +95,7 @@ def vote_for_post(post_id: int, vote_direction, src, auth=None): # post_bookmark in app/post/routes would just need to do 'return bookmark_the_post(post_id, SRC_WEB)' def bookmark_the_post(post_id: int, src, auth=None): if src == SRC_API: - post = Post.query.get(post_id) - if not post or post.deleted: - raise Exception('post_not_found') + post = Post.query.filter_by(id=post_id, deleted=False).one() user_id = authorise_api_user(auth) else: post = Post.query.get_or_404(post_id) @@ -127,9 +123,7 @@ def bookmark_the_post(post_id: int, src, auth=None): # post_remove_bookmark in app/post/routes would just need to do 'return remove_the_bookmark_from_post(post_id, SRC_WEB)' def remove_the_bookmark_from_post(post_id: int, src, auth=None): if src == SRC_API: - post = Post.query.get(post_id) - if not post or post.deleted: - raise Exception('post_not_found') + post = Post.query.filter_by(id=post_id, deleted=False).one() user_id = authorise_api_user(auth) else: post = Post.query.get_or_404(post_id) @@ -156,9 +150,7 @@ def remove_the_bookmark_from_post(post_id: int, src, auth=None): def toggle_post_notification(post_id: int, src, auth=None): # Toggle whether the current user is subscribed to notifications about top-level replies to this post or not if src == SRC_API: - post = Post.query.get(post_id) - if not post or post.deleted: - raise Exception('post_not_found') + post = Post.query.filter_by(id=post_id, deleted=False).one() user_id = authorise_api_user(auth) else: post = Post.query.get_or_404(post_id) @@ -173,10 +165,8 @@ def toggle_post_notification(post_id: int, src, auth=None): db.session.delete(existing_notification) db.session.commit() else: # no subscription yet, so make one - new_notification = NotificationSubscription(name=shorten_string(_('Replies to my post %(post_title)s', - post_title=post.title)), - user_id=user_id, entity_id=post.id, - type=NOTIF_POST) + new_notification = NotificationSubscription(name=shorten_string(_('Replies to my post %(post_title)s', post_title=post.title)), + user_id=user_id, entity_id=post.id, type=NOTIF_POST) db.session.add(new_notification) db.session.commit() diff --git a/app/shared/reply.py b/app/shared/reply.py index 04f8b324..357a5875 100644 --- a/app/shared/reply.py +++ b/app/shared/reply.py @@ -21,9 +21,7 @@ SRC_API = 3 def vote_for_reply(reply_id: int, vote_direction, src, auth=None): if src == SRC_API: - reply = PostReply.query.get(reply_id) - if not reply: - raise Exception('reply_not_found') + reply = PostReply.query.filter_by(id=reply_id).one() user = authorise_api_user(auth, return_type='model') else: reply = PostReply.query.get_or_404(reply_id) @@ -98,9 +96,7 @@ def vote_for_reply(reply_id: int, vote_direction, src, auth=None): # post_reply_bookmark in app/post/routes would just need to do 'return bookmark_the_post_reply(comment_id, SRC_WEB)' def bookmark_the_post_reply(comment_id: int, src, auth=None): if src == SRC_API: - post_reply = PostReply.query.get(comment_id) - if not post_reply or post_reply.deleted: - raise Exception('comment_not_found') + post_reply = PostReply.query.filter_by(id=comment_id, deleted=False).one() user_id = authorise_api_user(auth) else: post_reply = PostReply.query.get_or_404(comment_id) @@ -129,9 +125,7 @@ def bookmark_the_post_reply(comment_id: int, src, auth=None): # post_reply_remove_bookmark in app/post/routes would just need to do 'return remove_the_bookmark_from_post_reply(comment_id, SRC_WEB)' def remove_the_bookmark_from_post_reply(comment_id: int, src, auth=None): if src == SRC_API: - post_reply = PostReply.query.get(comment_id) - if not post_reply or post_reply.deleted: - raise Exception('comment_not_found') + post_reply = PostReply.query.filter_by(id=comment_id, deleted=False).one() user_id = authorise_api_user(auth) else: post_reply = PostReply.query.get_or_404(comment_id) @@ -158,9 +152,7 @@ def remove_the_bookmark_from_post_reply(comment_id: int, src, auth=None): def toggle_post_reply_notification(post_reply_id: int, src, auth=None): # Toggle whether the current user is subscribed to notifications about replies to this reply or not if src == SRC_API: - post_reply = PostReply.query.get(post_reply_id) - if not post_reply or post_reply.deleted: - raise Exception('comment_not_found') + post_reply = PostReply.query.filter_by(id=post_reply_id, deleted=False).one() user_id = authorise_api_user(auth) else: post_reply = PostReply.query.get_or_404(post_reply_id) @@ -226,9 +218,7 @@ def make_reply(input, post, parent_id, src, auth=None): language_id = input.language_id.data if parent_id: - parent_reply = PostReply.query.get(parent_id) - if not parent_reply: - raise Exception('parent_reply_not_found') + parent_reply = PostReply.query.filter_by(id=parent_id).one() else: parent_reply = None @@ -375,7 +365,7 @@ def edit_reply(input, reply, post, src, auth=None): flash(_('Your changes have been saved.'), 'success') if reply.parent_id: - in_reply_to = PostReply.query.get(reply.parent_id) + in_reply_to = PostReply.query.filter_by(id=reply.parent_id).one() else: in_reply_to = post diff --git a/app/utils.py b/app/utils.py index de9cbebf..23460d10 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1250,23 +1250,17 @@ def authorise_api_user(auth, return_type=None, id_match=None): raise Exception('incorrect_login') token = auth[7:] # remove 'Bearer ' - try: - decoded = jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=["HS256"]) - if decoded: - user_id = decoded['sub'] - issued_at = decoded['iat'] # use to check against blacklisted JWTs - user = User.query.filter_by(id=user_id, ap_id=None, verified=True, banned=False, deleted=False).scalar() - if user: - if id_match and user.id != id_match: - raise Exception('incorrect_login') - if return_type and return_type == 'model': - return user - else: - return user.id - else: - raise Exception('incorrect_login') - except jwt.InvalidTokenError: - raise Exception('invalid_token') + decoded = jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=["HS256"]) + if decoded: + user_id = decoded['sub'] + issued_at = decoded['iat'] # use to check against blacklisted JWTs + user = User.query.filter_by(id=user_id, ap_id=None, verified=True, banned=False, deleted=False).one() + if id_match and user.id != id_match: + raise Exception('incorrect_login') + if return_type and return_type == 'model': + return user + else: + return user.id @cache.memoize(timeout=86400)