replace requests with httpx #15

better thread safety
This commit is contained in:
rimu 2024-09-15 19:30:45 +12:00
parent d0cbf592ea
commit e616ce122f
11 changed files with 73 additions and 130 deletions

View file

@ -14,6 +14,7 @@ from flask_babel import Babel, lazy_gettext as _l
from flask_caching import Cache
from celery import Celery
from sqlalchemy_searchable import make_searchable
import httpx
from config import Config
@ -41,6 +42,7 @@ bootstrap = Bootstrap5()
babel = Babel(locale_selector=get_locale)
cache = Cache()
celery = Celery(__name__, broker=Config.CELERY_BROKER_URL)
httpx_client = httpx.Client(http2=True, limits=httpx.Limits(max_connections=20))
def create_app(config_class=Config):

View file

@ -34,7 +34,7 @@ import json
from typing import Literal, TypedDict, cast
from urllib.parse import urlparse
import requests
import httpx
import arrow
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
@ -44,7 +44,7 @@ from datetime import datetime
from dateutil import parser
from pyld import jsonld
from email.utils import formatdate
from app import db, celery
from app import db, celery, httpx_client
from app.constants import DATETIME_MS_FORMAT
from app.models import utcnow, ActivityPubLog, Community, Instance, CommunityMember, User
from sqlalchemy import text
@ -391,21 +391,17 @@ class HttpSignature:
# Send the request with all those headers except the pseudo one
del headers["(request-target)"]
try:
response = requests.request(
response = httpx_client.request(
method,
uri,
headers=headers,
data=body_bytes,
timeout=timeout,
allow_redirects=method == "GET",
follow_redirects=method == "GET",
)
except requests.exceptions.SSLError as invalid_cert:
# Not our problem if the other end doesn't have proper SSL
current_app.logger.info(f"{uri} {invalid_cert}")
raise requests.exceptions.SSLError from invalid_cert
except ValueError as ex:
except httpx.HTTPError as ex:
# Convert to a more generic error we handle
raise requests.exceptions.RequestException(f"InvalidCodepoint: {str(ex)}") from None
raise httpx.HTTPError(f"HTTP Exception for {ex.request.url} - {ex}") from None
if (
method == "POST"

View file

@ -6,6 +6,7 @@ from datetime import timedelta, datetime, timezone
from random import randint
from typing import Union, Tuple, List
import httpx
import redis
from flask import current_app, request, g, url_for, json
from flask_babel import _
@ -17,10 +18,6 @@ from app.models import User, Post, Community, BannedInstances, File, PostReply,
Language, Tag, Poll, PollChoice, UserFollower, CommunityBan, CommunityJoinRequest, NotificationSubscription
from app.activitypub.signature import signed_get_request, post_request
import time
import base64
import requests
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from app.constants import *
from urllib.parse import urlparse, parse_qs
from PIL import Image, ImageOps
@ -93,39 +90,6 @@ def local_communities():
return db.session.execute(text('SELECT COUNT(id) as c FROM "community" WHERE instance_id = 1')).scalar()
def send_activity(sender: User, host: str, content: str):
    """Deliver a signed ActivityPub payload to another instance's inbox.

    Builds an HTTP Signature (draft-cavage style) over the
    "(request-target) host date" pseudo-headers using the sender's RSA
    private key, then POSTs *content* to https://{host}/inbox.

    Returns the HTTP status code of the delivery attempt.
    Raises requests.exceptions.RequestException if the retry also fails.
    """
    date = time.strftime('%a, %d %b %Y %H:%M:%S UTC', time.gmtime())
    # sender.private_key is a PEM-encoded, unencrypted RSA key.
    private_key = serialization.load_pem_private_key(sender.private_key, password=None)

    # todo: look up instance details to set host_inbox
    host_inbox = '/inbox'

    # The string to sign covers exactly the headers named in the
    # Signature header below: (request-target), host, date.
    signed_string = f"(request-target): post {host_inbox}\nhost: {host}\ndate: " + date
    signature = private_key.sign(signed_string.encode('utf-8'), padding.PKCS1v15(), hashes.SHA256())
    encoded_signature = base64.b64encode(signature).decode('utf-8')

    # Construct the Signature header
    header = f'keyId="https://{current_app.config["SERVER_NAME"]}/u/{sender.user_name}",headers="(request-target) host date",signature="{encoded_signature}"'

    # Create headers for the request
    headers = {
        'Host': host,
        'Date': date,
        'Signature': header
    }

    # Make the HTTP request. On any transport failure, pause briefly and
    # retry once with a halved timeout; a second failure propagates.
    try:
        response = requests.post(f'https://{host}{host_inbox}', headers=headers, data=content,
                                 timeout=REQUEST_TIMEOUT)
    except requests.exceptions.RequestException:
        time.sleep(1)
        response = requests.post(f'https://{host}{host_inbox}', headers=headers, data=content,
                                 timeout=REQUEST_TIMEOUT / 2)
    return response.status_code
def post_to_activity(post: Post, community: Community):
# local PieFed posts do not have a create or announce id
create_id = post.ap_create_id if post.ap_create_id else f"https://{current_app.config['SERVER_NAME']}/activities/create/{gibberish(15)}"
@ -341,14 +305,12 @@ def find_actor_or_create(actor: str, create_if_not_found=True, community_only=Fa
if not signed_get:
try:
actor_data = get_request(actor_url, headers={'Accept': 'application/activity+json'})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
time.sleep(randint(3, 10))
try:
actor_data = get_request(actor_url, headers={'Accept': 'application/activity+json'})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
return None
except requests.exceptions.ConnectionError:
return None
if actor_data.status_code == 200:
try:
actor_json = actor_data.json()
@ -379,7 +341,7 @@ def find_actor_or_create(actor: str, create_if_not_found=True, community_only=Fa
try:
webfinger_data = get_request(f"https://{server}/.well-known/webfinger",
params={'resource': f"acct:{address}@{server}"})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
time.sleep(randint(3, 10))
webfinger_data = get_request(f"https://{server}/.well-known/webfinger",
params={'resource': f"acct:{address}@{server}"})
@ -392,7 +354,7 @@ def find_actor_or_create(actor: str, create_if_not_found=True, community_only=Fa
# retrieve the activitypub profile
try:
actor_data = get_request(links['href'], headers={'Accept': type})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
time.sleep(randint(3, 10))
actor_data = get_request(links['href'], headers={'Accept': type})
# to see the structure of the json contained in actor_data, do a GET to https://lemmy.world/c/technology with header Accept: application/activity+json
@ -487,11 +449,11 @@ def refresh_user_profile_task(user_id):
if user and user.instance_id and user.instance.online():
try:
actor_data = get_request(user.ap_public_url, headers={'Accept': 'application/activity+json'})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
time.sleep(randint(3, 10))
try:
actor_data = get_request(user.ap_public_url, headers={'Accept': 'application/activity+json'})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
return
except:
try:
@ -567,7 +529,7 @@ def refresh_community_profile_task(community_id):
if community and community.instance.online() and not community.is_local():
try:
actor_data = get_request(community.ap_public_url, headers={'Accept': 'application/activity+json'})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
time.sleep(randint(3, 10))
try:
actor_data = get_request(community.ap_public_url, headers={'Accept': 'application/activity+json'})
@ -1211,7 +1173,7 @@ def new_instance_profile_task(instance_id: int):
try:
instance_json = instance_data.json()
instance_data.close()
except requests.exceptions.JSONDecodeError as ex:
except Exception as ex:
instance_json = {}
if 'type' in instance_json and instance_json['type'] == 'Application':
instance.inbox = instance_json['inbox']
@ -1258,11 +1220,9 @@ def new_instance_profile_task(instance_id: int):
instance.updated_at = utcnow()
db.session.commit()
HEADERS = {'User-Agent': 'PieFed/1.0', 'Accept': 'application/activity+json'}
headers = {'User-Agent': 'PieFed/1.0', 'Accept': 'application/activity+json'}
try:
nodeinfo = requests.get(f"https://{instance.domain}/.well-known/nodeinfo", headers=HEADERS,
timeout=5, allow_redirects=True)
nodeinfo = get_request(f"https://{instance.domain}/.well-known/nodeinfo", headers=headers)
if nodeinfo.status_code == 200:
nodeinfo_json = nodeinfo.json()
for links in nodeinfo_json['links']:
@ -1272,8 +1232,7 @@ def new_instance_profile_task(instance_id: int):
links['rel'] == 'http://nodeinfo.diaspora.software/ns/schema/2.1'): # Lemmy v0.19.4+ (no 2.0 back-compat provided here)
try:
time.sleep(0.1)
node = requests.get(links['href'], headers=HEADERS, timeout=5,
allow_redirects=True)
node = get_request(links['href'], headers=headers)
if node.status_code == 200:
node_json = node.json()
if 'software' in node_json:

View file

@ -1,7 +1,5 @@
import os
from datetime import datetime, timedelta
from io import BytesIO
import requests as r
from datetime import timedelta
from time import sleep
from flask import request, flash, json, url_for, current_app, redirect, g, abort

View file

@ -1,10 +1,11 @@
import logging
import requests
from flask import Markup, current_app, request, session
from wtforms import ValidationError
from wtforms.fields import HiddenField
from wtforms.widgets import HiddenInput
from app import httpx_client
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = '''
@ -83,7 +84,7 @@ class Recaptcha3Validator(object):
'response': response
}
http_response = requests.post(RECAPTCHA_VERIFY_SERVER, data, timeout=10)
http_response = httpx_client.post(RECAPTCHA_VERIFY_SERVER, data, timeout=10)
if http_response.status_code != 200:
return False

View file

@ -1,13 +1,10 @@
import random
from datetime import timedelta
from unicodedata import normalize
import requests
from flask import current_app
import app
from app import cache
from app.models import utcnow
from app.utils import get_request
# Return a random string of 6 letter/digits.
@ -31,7 +28,7 @@ def ip2location(ip: str):
if not current_app.config['IPINFO_TOKEN']:
return {}
url = 'http://ipinfo.io/' + ip + '?token=' + current_app.config['IPINFO_TOKEN']
response = requests.get(url, timeout=5)
response = get_request(url)
if response.status_code == 200:
data = response.json()
cache.set('ip_' + ip, data, timeout=86400)

View file

@ -7,7 +7,7 @@ from random import randint
from time import sleep
import flask
import requests
import httpx
from flask import json, current_app
from flask_babel import _
from sqlalchemy import or_, desc, text
@ -214,8 +214,7 @@ def register(app):
# 'solves' this by redirecting calls for nodeinfo/2.0.json to nodeinfo/2.1
if not nodeinfo_href:
try:
nodeinfo = requests.get(f"https://{instance.domain}/.well-known/nodeinfo", headers=HEADERS,
timeout=5, allow_redirects=True)
nodeinfo = get_request(f"https://{instance.domain}/.well-known/nodeinfo", headers=HEADERS)
if nodeinfo.status_code == 200:
nodeinfo_json = nodeinfo.json()
@ -234,16 +233,12 @@ def register(app):
instance.failures += 1
elif nodeinfo.status_code >= 400:
current_app.logger.info(f"{instance.domain} has no well-known/nodeinfo response")
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
instance.failures += 1
except requests.exceptions.ConnectionError:
instance.failures += 1
except requests.exceptions.RequestException:
pass
if instance.nodeinfo_href:
try:
node = requests.get(instance.nodeinfo_href, headers=HEADERS, timeout=5, allow_redirects=True)
node = get_request(instance.nodeinfo_href, headers=HEADERS)
if node.status_code == 200:
node_json = node.json()
if 'software' in node_json:
@ -254,7 +249,7 @@ def register(app):
elif node.status_code >= 400:
instance.failures += 1
instance.nodeinfo_href = None
except requests.exceptions.RequestException:
except httpx.HTTPError:
instance.failures += 1
instance.nodeinfo_href = None
if instance.failures > 7 and instance.dormant == True:

View file

@ -2,7 +2,8 @@ from datetime import datetime, timedelta
from time import sleep
from random import randint
from typing import List
import requests
import httpx
from PIL import Image, ImageOps
from flask import request, abort, g, current_app, json
from flask_login import current_user
@ -47,15 +48,14 @@ def search_for_community(address: str):
try:
webfinger_data = get_request(f"https://{server}/.well-known/webfinger",
params={'resource': f"acct:{address[1:]}"})
except requests.exceptions.ReadTimeout:
except httpx.HTTPError:
sleep(randint(3, 10))
try:
webfinger_data = get_request(f"https://{server}/.well-known/webfinger",
params={'resource': f"acct:{address[1:]}"})
except requests.exceptions.RequestException:
except httpx.HTTPError:
return None
except requests.exceptions.RequestException:
return None
if webfinger_data.status_code == 200:
webfinger_json = webfinger_data.json()
for links in webfinger_json['links']:

View file

@ -22,7 +22,7 @@ from app.utils import render_template, get_setting, request_etag_matches, return
ap_datetime, shorten_string, markdown_to_text, user_filters_home, \
joined_communities, moderating_communities, markdown_to_html, allowlist_html, \
blocked_instances, communities_banned_from, topic_tree, recently_upvoted_posts, recently_downvoted_posts, \
blocked_users, menu_topics, languages_for_form, blocked_communities
blocked_users, menu_topics, languages_for_form, blocked_communities, get_request
from app.models import Community, CommunityMember, Post, Site, User, utcnow, Topic, Instance, \
Notification, Language, community_language, ModLog
@ -448,6 +448,13 @@ def list_files(directory):
@bp.route('/test')
def test():
response = get_request('https://rimu.geek.nz')
x = ''
if response.status_code == 200:
x =response.content
response.close()
return x
json = {
"@context": "https://www.w3.org/ns/activitystreams",
"actor": "https://ioc.exchange/users/haiviittech",

View file

@ -3,7 +3,6 @@ from time import time
from typing import List, Union
from urllib.parse import urlparse, parse_qs, urlencode, urlunparse
import requests
from flask import current_app, escape, url_for, render_template_string
from flask_login import UserMixin, current_user
from sqlalchemy import or_, text, desc
@ -15,7 +14,7 @@ from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.mutable import MutableList
from flask_sqlalchemy import BaseQuery
from sqlalchemy_searchable import SearchQueryMixin
from app import db, login, cache, celery
from app import db, login, cache, celery, httpx_client
import jwt
import os
import math
@ -342,7 +341,7 @@ def flush_cdn_cache_task(to_purge: Union[str, List[str]]):
}
if body:
response = requests.request(
response = httpx_client.request(
'POST',
f'https://api.cloudflare.com/client/v4/zones/{zone_id}/purge_cache',
headers=headers,

View file

@ -11,6 +11,7 @@ from datetime import datetime, timedelta, date
from time import sleep
from typing import List, Literal, Union
import httpx
import markdown2
import math
from urllib.parse import urlparse, parse_qs, urlencode
@ -19,10 +20,8 @@ import flask
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
import warnings
from app.activitypub.signature import default_context
warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
import requests
import os
from flask import current_app, json, redirect, url_for, request, make_response, Response, g, flash
from flask_babel import _
@ -30,7 +29,7 @@ from flask_login import current_user, logout_user
from sqlalchemy import text, or_
from wtforms.fields import SelectField, SelectMultipleField
from wtforms.widgets import Select, html_params, ListWidget, CheckboxInput
from app import db, cache
from app import db, cache, httpx_client
import re
from moviepy.editor import VideoFileClip
from PIL import Image, ImageOps
@ -81,7 +80,7 @@ def getmtime(filename):
# do a GET request to a uri, return the result
def get_request(uri, params=None, headers=None) -> requests.Response:
def get_request(uri, params=None, headers=None) -> httpx.Response:
timeout = 15 if 'washingtonpost.com' in uri else 5 # Washington Post is really slow on og:image for some reason
if headers is None:
headers = {'User-Agent': 'PieFed/1.0'}
@ -92,50 +91,39 @@ def get_request(uri, params=None, headers=None) -> requests.Response:
else:
payload_str = urllib.parse.urlencode(params) if params else None
try:
response = requests.get(uri, params=payload_str, headers=headers, timeout=timeout, allow_redirects=True)
except requests.exceptions.SSLError as invalid_cert:
# Not our problem if the other end doesn't have proper SSL
current_app.logger.info(f"{uri} {invalid_cert}")
raise requests.exceptions.SSLError from invalid_cert
response = httpx_client.get(uri, params=payload_str, headers=headers, timeout=timeout, follow_redirects=True)
except ValueError as ex:
# Convert to a more generic error we handle
raise requests.exceptions.RequestException(f"InvalidCodepoint: {str(ex)}") from None
except requests.exceptions.ReadTimeout as read_timeout:
raise httpx.HTTPError(f"HTTPError: {str(ex)}") from None
except httpx.ReadError as connection_error:
try: # retry, this time with a longer timeout
sleep(random.randint(3, 10))
response = requests.get(uri, params=payload_str, headers=headers, timeout=timeout * 2, allow_redirects=True)
except Exception as e:
current_app.logger.info(f"{uri} {read_timeout}")
raise requests.exceptions.ReadTimeout from read_timeout
except requests.exceptions.ConnectionError as connection_error:
try: # retry, this time with a longer timeout
sleep(random.randint(3, 10))
response = requests.get(uri, params=payload_str, headers=headers, timeout=timeout * 2, allow_redirects=True)
response = httpx_client.get(uri, params=payload_str, headers=headers, timeout=timeout * 2, follow_redirects=True)
except Exception as e:
current_app.logger.info(f"{uri} {connection_error}")
raise requests.exceptions.ConnectionError from connection_error
raise httpx_client.ReadError from connection_error
except httpx.HTTPError as read_timeout:
try: # retry, this time with a longer timeout
sleep(random.randint(3, 10))
response = httpx_client.get(uri, params=payload_str, headers=headers, timeout=timeout * 2, follow_redirects=True)
except Exception as e:
current_app.logger.info(f"{uri} {read_timeout}")
raise httpx.HTTPError from read_timeout
return response
# do a HEAD request to a uri, return the result
def head_request(uri, params=None, headers=None) -> httpx.Response:
    """Perform a HEAD request against *uri* and return the httpx.Response.

    A PieFed User-Agent is always set (merged into any caller-supplied
    headers). Redirects are followed, matching the behaviour of the old
    requests-based implementation.

    Raises httpx.HTTPError (logged first) on any transport failure.
    """
    if headers is None:
        headers = {'User-Agent': 'PieFed/1.0'}
    else:
        headers.update({'User-Agent': 'PieFed/1.0'})
    try:
        # httpx spells this keyword follow_redirects; requests'
        # allow_redirects does not exist in httpx and raises TypeError.
        response = httpx_client.head(uri, params=params, headers=headers, timeout=5,
                                     follow_redirects=True)
    except httpx.HTTPError as er:
        current_app.logger.info(f"{uri} {er}")
        # httpx.HTTPError requires a message argument; re-raising the bare
        # class would fail with TypeError at raise time.
        raise httpx.HTTPError(f"{uri} {er}") from er
    return response
@ -227,14 +215,14 @@ def is_video_hosting_site(url: str) -> bool:
def mime_type_using_head(url):
    """Return the MIME type of *url*, or '' if it cannot be determined.

    Does a HEAD request (same as GET except only the HTTP headers are
    transferred) and reads the Content-Type response header. Any HTTP
    or transport error yields '' rather than raising.
    """
    try:
        response = httpx_client.head(url, timeout=5)
        response.raise_for_status()  # raises httpx.HTTPStatusError (an HTTPError) for 4xx/5xx
        content_type = response.headers.get('Content-Type')
        if content_type:
            return content_type
        else:
            return ''
    except httpx.HTTPError:
        # Unused `as e` binding removed; failure details are not needed here.
        return ''
@ -525,7 +513,7 @@ def blocked_referrers() -> List[str]:
def retrieve_block_list():
try:
response = requests.get('https://raw.githubusercontent.com/rimu/no-qanon/master/domains.txt', timeout=1)
response = httpx_client.get('https://raw.githubusercontent.com/rimu/no-qanon/master/domains.txt', timeout=1)
except:
return None
if response and response.status_code == 200:
@ -534,7 +522,7 @@ def retrieve_block_list():
def retrieve_peertube_block_list():
try:
response = requests.get('https://peertube_isolation.frama.io/list/peertube_isolation.json', timeout=1)
response = httpx_client.get('https://peertube_isolation.frama.io/list/peertube_isolation.json', timeout=1)
except:
return None
list = ''
@ -542,6 +530,7 @@ def retrieve_peertube_block_list():
response_data = response.json()
for row in response_data['data']:
list += row['value'] + "\n"
response.close()
return list.strip()
@ -951,7 +940,7 @@ def opengraph_parse(url):
def url_to_thumbnail_file(filename) -> File:
try:
timeout = 15 if 'washingtonpost.com' in filename else 5 # Washington Post is really slow for some reason
response = requests.get(filename, timeout=timeout)
response = httpx_client.get(filename, timeout=timeout)
except:
return None
if response.status_code == 200:
@ -1118,7 +1107,7 @@ def in_sorted_list(arr, target):
# Makes a still image from a video url, without downloading the whole video file
def generate_image_from_video_url(video_url, output_path, length=2):
response = requests.get(video_url, stream=True, timeout=5,
response = httpx_client.get(video_url, stream=True, timeout=5,
headers={'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:127.0) Gecko/20100101 Firefox/127.0'}) # Imgur requires a user agent
content_type = response.headers.get('Content-Type')
if content_type: