Skip to content

WebSocket Views

Class-based WebSocket endpoints. Import them from `fastapi_views.views.websockets`.

For a complete walkthrough, see the WebSockets guide.


WebSocketAPIView

Bases: DependencyMixin, ABC, Generic[RecvT, SendT]

Base class for implementing WebSocket views.

Source code in fastapi_views/views/websockets.py
class WebSocketAPIView(DependencyMixin, ABC, Generic[RecvT, SendT]):
    """
    Base class for implementing class-based WebSocket views.

    Subclasses implement :meth:`handler`. For each connection, the endpoint
    built by :meth:`get_websocket_endpoint` does the following:

    1. Accepts the socket, registers it in the per-class connection list and
       calls :meth:`on_connect`.
    2. Runs two tasks concurrently in one task group:

       * a receiver that reads binary frames, validates them as JSON with the
         "receive" serializer and pushes the results onto an in-memory
         stream, exposed as :attr:`messages`;
       * the user :meth:`handler`.

    When the handler returns, the whole task group is cancelled. The same
    happens when the receiver stops on a validation error or a disconnect.
    Afterwards the socket is unregistered and closed, and
    :meth:`on_disconnect` is called.

    Type parameters:
        RecvT: type of validated inbound messages.
        SendT: type of outbound messages given to :meth:`send` /
            :meth:`broadcast`.
    """

    # Per-subclass logger named "<module>:<view name>" (set in __init_subclass__).
    logger: Logger
    # Optional view name; get_name() falls back to the class name when unset.
    name: str
    # Schema used for both directions by default; None -> AnyTypeAdapter.
    message_schema: type[RecvT] | None = None
    # Keyword arguments forwarded to TypeAdapter.dump_json when sending.
    default_serializer_options: ClassVar[SerializerOptions] = {
        "by_alias": True,
    }
    # When True, outbound objects are validated before being dumped to JSON.
    validate_on_send: bool = True

    # Seconds allowed for the shielded cleanup after a connection ends.
    disconnect_timeout: int = 30
    # TypeAdapter cache keyed by schema. It is defined once on this base class
    # and mutated in place, so it is shared by every subclass.
    _serializers: ClassVar[TypeAdapterMap] = {}
    # Sockets currently registered for this view class (fresh list per subclass).
    _connections: ClassVar[list[WebSocket]]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Initialise per-subclass state (connection registry and logger)."""
        # Cooperate with the other bases instead of silently skipping their
        # hooks. typing.Generic, for example, computes __parameters__ for
        # generic subclasses here, and DependencyMixin may define a hook too.
        super().__init_subclass__(**kwargs)
        cls._connections = []
        cls.logger = getLogger(f"{cls.__module__}:{cls.get_name()}")

    def __init__(self, websocket: WebSocket) -> None:
        """Bind the view to a single connection and create its message stream."""
        self.websocket = websocket
        # Passed as `context` to validate_json for every inbound message.
        self.validation_context = None
        # Per-instance copy, so one view can adjust options without touching
        # the class-level default.
        self.serializer_options = self.default_serializer_options.copy()
        # _receiver writes to _snd; the handler reads from _rcv via `messages`.
        self._snd, self._rcv = create_memory_object_stream()

    @classmethod
    def get_name(cls) -> str:
        """Return the explicit ``name`` attribute, or the class name if unset."""
        return getattr(cls, "name", cls.__name__)

    @classmethod
    def get_message_schema(cls, action: WebSocketAction) -> Any:  # noqa: ARG003
        """
        Return the schema used for ``action`` ("send" or "receive").

        The default ignores ``action`` and returns :attr:`message_schema` for
        both directions. Override this to use a different schema per direction.
        """
        return cls.message_schema

    @overload
    @classmethod
    def get_serializer(cls, action: Literal["send"]) -> TypeAdapter[SendT]: ...

    @overload
    @classmethod
    def get_serializer(cls, action: Literal["receive"]) -> TypeAdapter[RecvT]: ...

    @classmethod
    def get_serializer(
        cls, action: WebSocketAction
    ) -> TypeAdapter[RecvT] | TypeAdapter[SendT]:
        """
        Return a TypeAdapter for the schema of ``action``.

        Returns ``AnyTypeAdapter`` when no schema is configured. Otherwise the
        adapter is built once per schema and cached in ``_serializers``.
        """
        schema = cls.get_message_schema(action)
        if schema is None:
            return AnyTypeAdapter
        if schema not in cls._serializers:
            cls._serializers[schema] = TypeAdapter(schema)
        return cls._serializers[schema]

    async def _receiver(self, cancel_scope: CancelScope) -> None:
        """
        Read binary frames, validate them and forward them to the handler.

        On a ValidationError or WebSocketDisconnect, log a warning and cancel
        ``cancel_scope``. This stops the handler task as well, so a single
        invalid message ends the connection. The send side of the stream is
        closed when this coroutine exits.
        """
        serializer = self.get_serializer("receive")
        try:
            async with self._snd:
                while True:
                    data = await self.websocket.receive_bytes()
                    message = serializer.validate_json(
                        data, context=self.validation_context
                    )
                    await self._snd.send(message)
        except (ValidationError, WebSocketDisconnect) as e:
            self.logger.warning(
                "Exception while receiving data from websocket", exc_info=e
            )
            cancel_scope.cancel()

    async def _handler(
        self, fn: Callable[[], Awaitable[None]], cancel_scope: CancelScope
    ) -> None:
        """
        Run the user handler ``fn``, then always cancel ``cancel_scope``.

        The receive side of the stream is closed when ``fn`` finishes. The
        cancellation then stops the receiver task.
        """
        try:
            async with self._rcv:
                await fn()
        finally:
            cancel_scope.cancel()

    @classmethod
    def get_websocket_endpoint(cls) -> Callable:
        """
        Build the coroutine that drives one connection's full lifecycle.

        The returned ``endpoint`` takes the view instance as ``self`` plus the
        arguments forwarded to :meth:`handler`. It is passed, together with
        :meth:`handler`, to ``_patch_endpoint_signature``. Presumably this
        lets the framework see the handler's parameters (TODO confirm in
        DependencyMixin).
        """

        async def endpoint(self: WebSocketAPIView, *args: Any, **kwargs: Any) -> None:
            try:
                await self.websocket.accept()
                self._connections.append(self.websocket)
                await self.on_connect()
                fn = functools.partial(self.handler, *args, **kwargs)
                async with create_task_group() as tg:
                    tg.start_soon(self._receiver, tg.cancel_scope)
                    tg.start_soon(self._handler, fn, tg.cancel_scope)
            finally:
                # Shielded so cleanup still runs if the endpoint is being
                # cancelled; bounded by disconnect_timeout.
                with fail_after(self.disconnect_timeout, shield=True):
                    # accept() may have raised before the socket was
                    # registered. An unconditional remove() would then raise
                    # ValueError and mask the original exception.
                    if self.websocket in self._connections:
                        self._connections.remove(self.websocket)
                    try:
                        await self.websocket.close()
                    finally:
                        # Run the user hook even if close() raises.
                        await self.on_disconnect()

        cls._patch_endpoint_signature(endpoint, cls.handler)
        return endpoint

    @classmethod
    def get_websocket_action(cls, prefix: str = "") -> dict[str, Any]:
        """
        Return the ``path``/``endpoint``/``name`` mapping for this view.

        Presumably this is used as keyword arguments when registering the
        WebSocket route; verify against callers.
        """
        endpoint = cls.get_websocket_endpoint()
        return {
            "path": prefix,
            "endpoint": endpoint,
            "name": cls.get_name(),
        }

    def _serialize_message(self, obj: Any) -> bytes:
        """
        Dump ``obj`` to JSON bytes with the "send" serializer.

        When ``validate_on_send`` is set, ``obj`` is validated first.
        Otherwise it is dumped as-is with serializer warnings suppressed.
        """
        serializer = self.get_serializer("send")
        if self.validate_on_send:
            obj = serializer.validate_python(obj)
            return serializer.dump_json(obj, **self.serializer_options)
        return serializer.dump_json(obj, warnings=False, **self.serializer_options)

    async def _safe_send(self, data: bytes, websocket: WebSocket | None = None) -> None:
        """
        Send ``data`` to ``websocket`` (defaults to this view's socket).

        WebSocketDisconnect and ClosedResourceError are logged as warnings
        instead of being raised.
        """
        if websocket is None:
            websocket = self.websocket
        try:
            await websocket.send_bytes(data)
        except (WebSocketDisconnect, ClosedResourceError) as e:
            self.logger.warning("Error sending bytes to websocket", exc_info=e)

    async def send(self, message: SendT) -> None:
        """Serialize ``message`` and send it to this connection."""
        data = self._serialize_message(message)
        await self._safe_send(data)

    async def broadcast(self, message: SendT) -> None:
        """
        Serialize ``message`` once and send it to every registered connection.

        Recipients are all sockets registered for this view class, including
        this one. Sends run concurrently in a task group.
        """
        data = self._serialize_message(message)
        async with create_task_group() as tg:
            for connection in self._connections:
                tg.start_soon(self._safe_send, data, connection)

    @property
    def messages(self) -> AsyncIterable[RecvT]:
        """Async iterable of validated inbound messages for use in ``handler``."""
        return self._rcv

    async def on_connect(self) -> None:
        """Hook called after the socket is accepted and registered; no-op by default."""
        pass

    async def on_disconnect(self) -> None:
        """Hook called during cleanup, after the socket is closed; no-op by default."""
        pass

    @abstractmethod
    async def handler(self, *args: Any, **kwargs: Any) -> None:
        """Connection logic implemented by subclasses; iterate ``self.messages`` to receive."""
        raise NotImplementedError