# Source code for atproto_firehose

"""ATProto firehose client. Enqueues receive tasks for events for bridged users.

https://atproto.com/specs/event-stream
https://atproto.com/specs/sync#firehose
"""
from collections import namedtuple
from datetime import datetime, timedelta
from io import BytesIO
import itertools
import logging
import os
from queue import Queue
from threading import Event, Lock, Thread, Timer
import threading
import time

from arroba.datastore_storage import AtpRepo
from arroba.util import at_uri, parse_at_uri
import dag_json
from google.cloud import ndb
from google.cloud.ndb.exceptions import ContextError
from lexrpc.base import AT_URI_RE
from lexrpc.client import Client
from lexrpc import ValidationError
import libipld
from multiformats import CID
from webutil import util
from webutil.appengine_config import ndb_client
from webutil.appengine_info import DEBUG
from webutil.util import json_dumps, json_loads

from atproto import ATProto, Cursor, DatastoreClient
from common import (
    BETA_USER_IDS,
    create_task,
    NDB_CONTEXT_KWARGS,
    report_error,
    report_exception,
    USER_AGENT,
)
from domains import PROTOCOL_DOMAINS
from models import Object, Target
from protocol import DELETE_TASK_DELAY
from web import Web

logger = logging.getLogger(__name__)

# lexrpc client constructed without a server address. only used to validate
# records against their lexicons, in _handle_commit_op
_validator = Client()

# how long subscriber sleeps after a disconnect before it reconnects
RECONNECT_DELAY = timedelta(seconds=30)
# minimum time between writes of the stored cursor in subscribe. also used as
# the period of the _load_dids timer
STORE_CURSOR_FREQ = timedelta(seconds=10)
# subscribe leaves the "behind" metric out of its log line when the current
# event's time is more than this far before the latest time seen so far
LOG_OUTLIER_THRESHOLD = timedelta(minutes=5)

# a commit operation. similar to arroba.repo.Write. record is None for deletes.
Op = namedtuple('Op', ['action', 'repo', 'path', 'seq', 'record', 'time'],
                # last four fields are optional
                defaults=[None, None, None, None])

# contains Ops
#
# maxsize is important here! if we hit this limit, subscribe will block when it
# tries to add more commits until handle consumes some. this keeps subscribe
# from getting too far ahead of handle and using too much memory in this queue.
commits = Queue(maxsize=1000)

# global so that subscribe can reuse it across calls
cursor = None

# global: _load_dids populates them, subscribe and handle use them
atproto_dids = set()  # native ATProto accounts that are bridged
# _load_dids only queries ATProto users updated after this time
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()  # accounts elsewhere that are bridged into ATProto
# _load_dids only queries AtpRepos updated after this time
bridged_loaded_at = datetime(1900, 1, 1)
# ATProto copy ids of the protocol bot users, one per PROTOCOL_DOMAINS entry
# that exists. loaded once by _load_dids
protocol_bot_dids = set()
# set by _load_dids once a load pass succeeds. load_dids waits on it, then
# clears it
dids_initialized = Event()


def load_dids():
    """Starts the background DID loader and blocks until its first pass is done.

    Runs :func:`_load_dids` on a daemon thread, waits for it to set
    ``dids_initialized``, then clears that event again.
    """
    logger.info('Starting _load_dids timer')

    # _load_dids makes its own NDB context, the same way it does when it later
    # runs in the timer thread, so give it a thread of its own here too
    loader = Thread(target=_load_dids, daemon=True)
    loader.start()

    dids_initialized.wait()
    dids_initialized.clear()


def _load_dids():
    """Loads newly bridged DIDs from the datastore into the module-level sets.

    Unless ``DEBUG`` is set, first schedules itself to run again in
    ``STORE_CURSOR_FREQ`` on a :class:`threading.Timer`, so the sets keep
    getting refreshed even if this pass fails.

    Each pass, inside its own NDB context:

    * adds active ATProto users with enabled protocols, updated since
      ``atproto_loaded_at``, to ``atproto_dids``
    * adds active ``AtpRepo``s updated since ``bridged_loaded_at`` to
      ``bridged_dids``
    * on the first successful pass, loads the protocol bot users' ATProto
      copy ids into ``protocol_bot_dids``
    * sets ``dids_initialized``

    An empty datastore is not an error: the watermarks stay where they are and
    ``dids_initialized`` still gets set, so :func:`load_dids` doesn't block.

    All exceptions, including ``BaseException``s, are reported and swallowed.
    """
    global atproto_dids, atproto_loaded_at, bridged_dids, bridged_loaded_at

    if not DEBUG:
        Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start()

    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        try:
            # note that ndb filters need ==/!= None, not is/is not
            atproto_query = ATProto.query(ATProto.status == None,
                                          ATProto.enabled_protocols != None,
                                          ATProto.updated > atproto_loaded_at)
            # fetch the new watermark *before* running the query above, so
            # that anything updated in between gets picked up again next time.
            # get() returns None if there are no ATProto users at all
            latest_atproto = ATProto.query().order(-ATProto.updated).get()
            new_atproto = [key.id() for key in atproto_query.iter(keys_only=True)]
            atproto_dids.update(new_atproto)
            # set *after* we populate atproto_dids so that if we crash earlier, we
            # re-query from the earlier timestamp
            if latest_atproto:
                atproto_loaded_at = latest_atproto.updated

            bridged_query = AtpRepo.query(AtpRepo.status == None,
                                          AtpRepo.updated > bridged_loaded_at)
            # likewise, None if there are no AtpRepos yet
            latest_repo = AtpRepo.query().order(-AtpRepo.created).get()
            new_bridged = [key.id() for key in bridged_query.iter(keys_only=True)]
            bridged_dids.update(new_bridged)
            # set *after* we populate bridged_dids so that if we crash earlier, we
            # re-query from the earlier timestamp
            if latest_repo:
                bridged_loaded_at = latest_repo.created

            if not protocol_bot_dids:
                bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS]
                for bot in ndb.get_multi(bot_keys):
                    if bot and (did := bot.get_copy(ATProto)):
                        logger.info(f'Loaded protocol bot user {bot.key.id()} {did}')
                        protocol_bot_dids.add(did)

            dids_initialized.set()
            total = len(atproto_dids) + len(bridged_dids)
            logger.info(f'DIDs: {total} ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)}); commits {commits.qsize()}')

            # TODO: remove
            # temporary, debugging https://github.com/snarfed/bridgy-fed/issues/2327
            import gc
            import resource
            # take one snapshot so the three generation counts are consistent
            gen0, gen1, gen2 = gc.get_count()
            logger.info(f'  threads {threading.active_count()} rss {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss} gc gen0={gen0} gen1={gen1} gen2={gen2}')

        except BaseException:
            # eg google.cloud.ndb.exceptions.ContextError when we lose the ndb context
            # https://console.cloud.google.com/errors/detail/CLO6nJnRtKXRyQE?project=bridgy-federated
            report_exception()


def subscriber():
    """Wrapper around :func:`subscribe` that catches exceptions and reconnects.

    Loads the bridged DIDs first, via :func:`load_dids`, then loops forever:
    runs :func:`subscribe` inside a fresh NDB context, and whenever it returns
    or raises, reports the exception (if any), sleeps ``RECONNECT_DELAY``, and
    reconnects. Never returns.

    Relay hostname comes from the ``BGS_HOST`` environment variable.
    """
    logger.info(f'started thread to subscribe to {os.environ["BGS_HOST"]} firehose')
    load_dids()

    while True:
        try:
            with ndb_client.context(**NDB_CONTEXT_KWARGS):
                subscribe()
        except BaseException:
            # deliberately broad: this thread should survive anything
            report_exception()

        logger.info(f'disconnected! waiting {RECONNECT_DELAY} and then reconnecting')
        time.sleep(RECONNECT_DELAY.total_seconds())
def subscribe():
    """Subscribes to the relay's firehose.

    Relay hostname comes from the ``BGS_HOST`` environment variable.

    Reads ``com.atproto.sync.subscribeRepos`` events starting at the stored
    :class:`Cursor`, and puts an :class:`Op` onto the ``commits`` queue for each
    one that :func:`handle` should process:

    * ``#account`` and ``#identity`` events for DIDs in ``atproto_dids`` or
      ``bridged_dids``
    * deletes in bridged ATProto users' repos
    * creates/updates of supported record types by bridged ATProto users, when
      they're standalone or target a bridged user
    * follows of protocol bots, and replies/quotes/mentions of bridged users,
      by unbridged ATProto users

    The in-memory cursor advances on every event and is stored at most once per
    ``STORE_CURSOR_FREQ``.

    Records here come straight off the network and haven't been validated
    against their lexicons yet, so fields are read defensively. A record that's
    missing an expected field is skipped instead of raising, which would
    otherwise disconnect us from the firehose.

    Returns when the relay closes the stream. Exceptions propagate up to
    :func:`subscriber`.
    """
    global cursor
    if not cursor:
        cursor = Cursor.get_or_insert(
            f'{os.environ["BGS_HOST"]} com.atproto.sync.subscribeRepos')
        # TODO: remove? does this make us skip events? if we remove it, will we
        # infinite loop when we fail on an event?
        if cursor.cursor:
            cursor.cursor += 1

    last_stored_cursor = cur_timestamp = last_timestamp = None

    def is_ours(did_or_ref, native):
        """Returns True if the arg is a bridged user.

        Args:
          did_or_ref (str or dict): if dict, a ``com.atproto.repo.strongRef``
            or similar
          native (bool): if True, bridged ATProto users also count. If False,
            only users from other protocols who are bridged into ATProto count
        """
        did = None
        if isinstance(did_or_ref, dict):
            uri = did_or_ref.get('uri')
            if isinstance(uri, str) and (match := AT_URI_RE.match(uri)):
                did = match.group('repo')
        elif isinstance(did_or_ref, str):
            did = did_or_ref

        return bool(did) and (did in bridged_dids
                              or (native and did in atproto_dids))

    client = Client(f'https://{os.environ["BGS_HOST"]}',
                    headers={'User-Agent': USER_AGENT})

    for header, payload in client.com.atproto.sync.subscribeRepos(
            cursor=cursor.cursor):
        # parse header
        if header.get('op') == -1:
            logger.warning(f'Got error from relay! {payload}')
            continue

        t = header.get('t')
        if t not in ('#commit', '#account', '#identity'):
            # #handle and #tombstone are ignored silently
            if t not in ('#handle', '#tombstone'):
                logger.info(f'Got {t} from relay')
            continue

        # parse payload
        repo = payload.get('repo') or payload.get('did')
        if not repo:
            logger.warning(f'Payload missing repo! {payload}')
            continue

        seq = payload.get('seq')
        if not seq:
            logger.warning(f'Payload missing seq! {payload}')
            continue

        if cur_timestamp and (not last_timestamp or cur_timestamp > last_timestamp):
            # prevent time from moving backwards for commits with outlier stale
            # time values. https://github.com/snarfed/bridgy-fed/issues/2348
            last_timestamp = cur_timestamp
        cur_timestamp = payload['time']

        # if we fail processing this commit and raise an exception up to subscriber,
        # skip it and start with the next commit when we're restarted
        cursor.cursor = seq + 1

        elapsed = util.now().replace(tzinfo=None) - cursor.updated
        if elapsed > STORE_CURSOR_FREQ:
            # it's been long enough, update our stored cursor and metrics
            msg = f'updating stored cursor to {cursor.cursor}'

            if last_stored_cursor:
                events_s = int((cursor.cursor - last_stored_cursor)
                               / elapsed.total_seconds())
                msg += f', {events_s} events/s'
            last_stored_cursor = cursor.cursor

            cur_dt = util.parse_iso8601(cur_timestamp)
            if last_timestamp:
                last_dt = util.parse_iso8601(last_timestamp)
                # skip the lag metric when this event's time is an outlier
                if last_dt - cur_dt <= LOG_OUTLIER_THRESHOLD:
                    behind = util.now() - cur_dt
                    msg += f', {behind} ({int(behind.total_seconds())} s) behind'

            logger.info(msg)
            cursor.put()
            # when running locally, comment out put above and uncomment this
            # cursor.updated = util.now().replace(tzinfo=None)

        if t in ('#account', '#identity'):
            if repo in atproto_dids or repo in bridged_dids:
                action = t.removeprefix('#')
                logger.info(f'Got {action} {repo}')
                commits.put(Op(action=action, repo=repo, seq=seq,
                               time=cur_timestamp))
            continue

        blocks = {}  # maps base32 str CID to dict block
        if block_bytes := payload.get('blocks'):
            # WARNING: in libipld v2+ (which we're on), this returns CIDs as raw
            # bytes! we base32-encode those to strings later, in
            # _handle_commit_op, via _encode_bytes_cids. background:
            # https://github.com/snarfed/bridgy-fed/issues/1316
            try:
                _, blocks = libipld.decode_car(block_bytes)
            except (TypeError, ValueError):
                report_error(
                    f'failed decoding blocks! skipping seq {seq} {repo} '
                    f'{payload.get("ops")} {block_bytes}',
                    exception=True)
                continue

        # detect records from bridged ATProto users that we should handle
        for p_op in payload.get('ops', []):
            op = Op(repo=repo, action=p_op.get('action'), path=p_op.get('path'),
                    seq=seq, time=cur_timestamp)
            if not op.action or not op.path:
                logger.info(
                    f'bad payload! seq {op.seq} action {op.action} path {op.path}!')
                continue

            if op.repo in atproto_dids and op.action == 'delete':
                # TODO: also detect deletes of records that *reference* our bridged
                # users, eg a delete of a follow or like or repost of them.
                # not easy because we need to getRecord the record to check
                commits.put(op)
                continue

            cid = p_op.get('cid')
            block = blocks.get(cid)
            # our own commits are sometimes missing the record
            # https://github.com/snarfed/bridgy-fed/issues/1016
            if not cid or not block:
                continue
            elif not isinstance(block, dict):
                # https://github.com/snarfed/bridgy-fed/issues/1938
                logger.info(f"Skipping odd record we couldn't understand (#1938): {op} {p_op} {repr(block)}")
                continue

            op = op._replace(record=block)
            record_type = op.record.get('$type')
            if not record_type:
                logger.warning(f'commit record missing $type! {op.action} {op.repo} {op.path} {cid}')
                logger.warning(dag_json.encode(op.record).decode())
                continue
            elif (record_type not in ATProto.SUPPORTED_RECORD_TYPES
                  and not (op.repo in BETA_USER_IDS
                           and record_type in ATProto.SUPPORTED_RECORD_TYPES_BETA_USERS)
                  and record_type not in ATProto.STORE_RECORD_TYPES):
                continue

            if op.repo in atproto_dids:
                # from a bridged Bluesky user
                if record_type == 'app.bsky.actor.profile':
                    commits.put(op)

                elif record_type == 'app.bsky.feed.repost':
                    if is_ours(op.record.get('subject'), native=True):
                        commits.put(op)

                elif record_type == 'app.bsky.feed.like':
                    if is_ours(op.record.get('subject'), native=False):
                        commits.put(op)

                elif record_type in ('app.bsky.graph.block',
                                     'app.bsky.graph.follow'):
                    if is_ours(op.record.get('subject'), native=False):
                        commits.put(op)

                elif record_type == 'app.bsky.feed.post':
                    reply = op.record.get('reply')
                    if not reply or is_ours(reply.get('parent'), native=True):
                        commits.put(op)

                # other lexicon that we support (checked earlier). go ahead and try
                # to bridge it.
                else:
                    commits.put(op)

            elif op.repo not in bridged_dids:
                # from an unbridged Bluesky user. only follows of protocol bots and
                # replies/quotes/mentions of bridged users, so that we can DM them a
                # notification
                if record_type == 'app.bsky.graph.follow':
                    followee = op.record.get('subject')
                    if isinstance(followee, str) and followee in protocol_bot_dids:
                        commits.put(op)

                elif record_type == 'app.bsky.feed.post':
                    subjects = []

                    if reply := op.record.get('reply'):
                        subjects.append(reply.get('parent'))

                    if embed := op.record.get('embed'):
                        if embed.get('$type') == 'app.bsky.embed.record':
                            subjects.append(embed.get('record'))

                    for facet in op.record.get('facets', []):
                        for feat in facet.get('features', []):
                            if feat.get('$type') == 'app.bsky.richtext.facet#mention':
                                subjects.append(feat.get('did'))

                    # enqueue at most once, no matter how many subjects match
                    if any(is_ours(subject, native=False) for subject in subjects):
                        commits.put(op)
def handler():
    """Wrapper around :func:`handle` that catches exceptions and restarts.

    Runs :func:`handle` inside an NDB context. If it raises, reports the
    exception and loops to run it again in a *new* NDB context. If it returns
    cleanly, which only happens when it hits its ``limit``, returns.
    """
    logger.info('started handle thread to store objects and enqueue receive tasks')

    while True:
        with ndb_client.context(**NDB_CONTEXT_KWARGS):
            try:
                handle()
                # if we return cleanly, that means we hit the limit
                break
            except BaseException:
                report_exception()
                # fall through to loop to create new ndb context in case this is
                # a ContextError
                # https://console.cloud.google.com/errors/detail/CIvwj_7MmsfOWw;time=P1D;locations=global?project=bridgy-federated


def _handle_commit_op(op):
    """Stores a commit operation's record and/or enqueues a receive task for it.

    * Creates/updates of ``ATProto.STORE_RECORD_TYPES`` records are just stored
      as :class:`Object`\\s; no task.
    * Creates/updates of ``ATProto.SUPPORTED_RECORD_TYPES`` records are
      validated against their lexicon, then enqueued with the record as
      ``bsky``. Invalid records are reported and skipped.
    * Deletes of supported types are enqueued as ``delete``, ``stop-following``,
      or ``undo`` AS1 activities, delayed by ``DELETE_TASK_DELAY``. Deletes of
      follows that we don't have stored are skipped, since we can't determine
      the followee.
    * Anything else is logged and skipped.

    Args:
      op (Op)

    Raises:
      ContextError: if we've lost our NDB context while creating the task.
        Other exceptions from creating the task are reported and swallowed.
    """
    # the record type is the collection, ie the first path segment. skip paths
    # without an rkey instead of blowing up on unpacking them
    record_type, slash, _ = op.path.strip('/').partition('/')
    if not slash:
        logger.warning(f'Skipping op with malformed path: {op}')
        return

    uri = f'at://{op.repo}/{op.path}'
    record = _encode_bytes_cids(op.record)

    if (record_type in ATProto.STORE_RECORD_TYPES
            and op.action in ('create', 'update')):
        # TODO: handle deletes
        logger.info(f'Just storing {op.seq} {op.action} {op.repo} {op.path}')
        assert record_type not in ATProto.SUPPORTED_RECORD_TYPES, (record_type, record)
        Object.get_or_create(uri, bsky=record, authed_as=op.repo,
                             source_protocol=ATProto.LABEL)
        if record_type == 'community.lexicon.payments.webMonetization':
            _handle_webMonetization(op)
        return

    if record_type not in ATProto.SUPPORTED_RECORD_TYPES:
        logger.debug(f'Skipping unsupported type {record_type}: {uri}')
        return

    # store object, enqueue receive task
    if op.action in ('create', 'update'):
        record_kwarg = {'bsky': record}
        obj_id = uri
        try:
            _validator.validate(record_type, 'record', record)
        except ValidationError as e:
            report_error(f'skipping invalid {record_type} record on firehose: {e}; {op}')
            return

        if record_type == 'site.standard.document':
            _handle_standard_site_document(op)

    elif op.action == 'delete':
        verb = (
            'delete' if record_type in ('app.bsky.actor.profile',
                                        'app.bsky.feed.post')
            else 'stop-following' if record_type == 'app.bsky.graph.follow'
            else 'undo')
        obj_id = f'{uri}#{verb}'
        record_kwarg = {
            'our_as1': {
                'objectType': 'activity',
                'verb': verb,
                'id': obj_id,
                'actor': op.repo,
                'object': uri,
            },
        }
        # stop-following object is followee id, not follow activity's id
        if record_type == 'app.bsky.graph.follow':
            if (follow := ATProto.load(uri, remote=False)) and follow.bsky:
                record_kwarg['our_as1']['object'] = follow.bsky['subject']
            else:
                return

    else:
        logger.error(f'Unknown action {op.action} for {op.repo} {op.path}')
        return

    logger.info(f'Got {op.seq} {op.action} {op.repo} {op.path}')
    delay = DELETE_TASK_DELAY if op.action == 'delete' else None
    try:
        create_task(queue='receive', id=obj_id, source_protocol=ATProto.LABEL,
                    authed_as=op.repo, received_at=op.time, delay=delay,
                    **record_kwarg)
        # when running locally, comment out above and uncomment this
        # logger.info(f'enqueuing receive task for {uri}')
    except ContextError:
        raise  # propagates up through handle() to handler(), which restarts
    except BaseException:
        report_error(obj_id, exception=True)
[docs] def handle(limit=None): """ Args: limit: integer (optional): only used in tests """ seen = 0 while op := commits.get(): match op.action: case 'account': # reload DID doc ATProto.load(op.repo, raw=True, remote=True) case 'identity': # reload DID doc, update user's computed handle property, send actor # update to followers ATProto.load(op.repo, raw=True, remote=True) if user := ATProto.get_by_id(op.repo): user.put() if user.obj and user.obj.as1: identity_op = Op(repo=op.repo, action='update', record=user.obj.bsky, path='app.bsky.actor.profile/self', seq=op.seq, time=op.time) try: _handle_commit_op(identity_op) except BaseException: logger.error(f'Error handling op: {identity_op}') raise case _: try: _handle_commit_op(op) except BaseException: logger.error(f'Error handling op: {op}') raise seen += 1 if limit is not None and seen >= limit: return assert False, "handle thread shouldn't reach here!"
def _handle_standard_site_document(op):
    """Enqueues a delete task for the bskyPostRef post if we've already bridged it.

    This is for the case when we see a document's post record first, and then
    later see the document itself. At that point, we've already bridged the post
    as a post, but really we want to bridge the document *instead* of the post.
    So, we delete the bridged version(s) of the post, and then bridge the
    document normally.

    Either way, afterward we add the document's URI to the stored post's
    copies, if the post can be loaded.

    https://github.com/snarfed/bridgy-fed/issues/2324

    Args:
      op (Op)
    """
    # bskyPostRef may be missing or null
    post_ref = op.record.get('bskyPostRef') or {}
    if not (post_uri := post_ref.get('uri')):
        return

    # if the post already exists in the datastore, and isn't linked to a document,
    # that's a (mediocre) heuristic that we bridged it
    if existing := ATProto.load(post_uri, remote=False):
        for copy in existing.get_copies(ATProto):
            _, coll, _ = parse_at_uri(copy)
            if coll == 'site.standard.document':
                break
        else:
            # no document copy found
            logger.warning(f'Deleting bskyPostRef {post_uri} that we already bridged')
            delete_id = f'{post_uri}#delete-{util.now().isoformat()}'
            delete_post_as1 = {
                'objectType': 'activity',
                'verb': 'delete',
                'id': delete_id,
                'actor': op.repo,
                'object': post_uri,
            }
            create_task(queue='receive', id=delete_id,
                        source_protocol=ATProto.LABEL, our_as1=delete_post_as1,
                        authed_as=op.repo, received_at=op.time,
                        delay=DELETE_TASK_DELAY)

    # link the bsky post to this doc
    @ndb.transactional()
    def add_doc_copy():
        if not (post := ATProto.load(post_uri)):
            logger.warning(f'bskyPostRef {post_uri} not found!')
            return

        doc_uri = f'at://{op.repo}/{op.path}'
        logger.warning(f'Adding doc copy {doc_uri} to bskyPostRef {post_uri}')
        post.add('copies', Target(protocol='atproto', uri=doc_uri))
        post.put()

    add_doc_copy()


@ndb.transactional()
def _handle_webMonetization(op):
    """Stores a Web Monetization record's address in the user's profile object.

    Sets ``extra_as1['monetization']`` on the user's
    ``app.bsky.actor.profile/self`` :class:`Object`, creating that object if
    necessary. Runs in a transaction.

    These records aren't validated against their lexicon before they get here,
    so a record without an ``address`` is logged and skipped.

    Args:
      op (Op)
    """
    address = op.record.get('address')
    if address is None:
        logger.warning(f'Skipping webMonetization record without address: {op}')
        return

    profile_uri = at_uri(op.repo, 'app.bsky.actor.profile', 'self')
    profile = Object.get_or_insert(profile_uri)
    if not profile.extra_as1:
        profile.extra_as1 = {}
    profile.extra_as1['monetization'] = address
    profile.put()


def _encode_bytes_cids(val):
    """Returns a copy of a record with bytes values converted to base32 CID strings.

    The input is not modified. Dicts are copied recursively, and lists and
    tuples both come back as lists.

    This is a compatibility hack for libipld v2+. v1's ``decode_car`` returned
    CIDs as strings, base32-encoded. v2 switched to raw bytes.

    https://github.com/snarfed/bridgy-fed/issues/1316
    https://github.com/MarshalX/python-libipld/releases/tag/v2.0.0

    We JSON-encode records and include them in task HTTP request bodies, so we
    need to encode these CIDs.

    One catch is that there are also bytes-valued fields in ATProto records,
    https://atproto.com/specs/lexicon#bytes . To do this right, we'd need to use
    the record's lexicon to introspect it, determine each value's type, and do
    the right thing.

    That would be a lot of work, though. :meth:`lexrpc.Base.validate` has the
    logic, but it's not usable for arbitrary transformations like this.
    Fortunately, as of now (Oct 2025), the only bytes-valued field in the
    app.bsky lexicons is in subscribeLabels, and not in a record. So, as a hack,
    we try to convert all bytes values to CIDs, and if that fails, let the
    exception propagate up. (Should never happen right now.)

    Args:
      val (record dict, or other dict or list or primitive value)

    Returns:
      same structure as ``val``, with bytes replaced by str CIDs

    Raises:
      ValueError: if a bytes value isn't a CID
    """
    if isinstance(val, bytes):
        return CID.decode(val).encode('base32')  # raises ValueError if not CID
    elif isinstance(val, dict):
        return {k: _encode_bytes_cids(v) for k, v in val.items()}
    elif isinstance(val, (list, tuple)):
        return [_encode_bytes_cids(v) for v in val]

    return val