diff --git a/app/activitypub/util.py b/app/activitypub/util.py index 23e9a329..2367a266 100644 --- a/app/activitypub/util.py +++ b/app/activitypub/util.py @@ -1520,9 +1520,22 @@ def create_post(activity_log: ActivityPubLog, community: Community, request_json def notify_about_post(post: Post): + + # Send notifications based on subscriptions to the author + notifications_sent_to = set() + for notify_id in post.author.notification_subscribers(): + new_notification = Notification(title=shorten_string(post.title, 50), url=f"/post/{post.id}", + user_id=notify_id, author_id=post.user_id) + db.session.add(new_notification) + user = User.query.get(notify_id) + user.unread_notifications += 1 + db.session.commit() + notifications_sent_to.add(notify_id) + + # Send notifications based on subscriptions to the community people_to_notify = CommunityMember.query.filter_by(community_id=post.community_id, notify_new_posts=True, is_banned=False) for person in people_to_notify: - if person.user_id != post.user_id: + if person.user_id != post.user_id and person.user_id not in notifications_sent_to: new_notification = Notification(title=shorten_string(post.title, 50), url=f"/post/{post.id}", user_id=person.user_id, author_id=post.user_id) db.session.add(new_notification) user = User.query.get(person.user_id) # todo: make this more efficient by doing a join with CommunityMember at the start of the function diff --git a/app/models.py b/app/models.py index 455ad0b7..ff27b01b 100644 --- a/app/models.py +++ b/app/models.py @@ -889,8 +889,8 @@ class User(UserMixin, db.Model): # ids of all the users who want to be notified when self makes a post def notification_subscribers(self): - return db.session.execute(text('SELECT user_id FROM "notification_subscription" WHERE entity_id = :user_id AND type = :type '), - {'user_id': self.id, 'type': NOTIF_USER}).scalars() + return list(db.session.execute(text('SELECT user_id FROM "notification_subscription" WHERE entity_id = :user_id AND type = :type '), + {'user_id': self.id, 'type': NOTIF_USER}).scalars()) class ActivityLog(db.Model):