FastAPI / Starlette CORS:排除特定端点

问题描述

我正在使用 FastAPI,并按如下方式配置了 CORSMiddleware:

# Register the CORS middleware for the single frontend origin.
# NOTE(review): the Origin header normally carries a scheme
# (e.g. "https://frontend.domain.com") — confirm this entry matches it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['frontend.domain.com'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

有没有办法排除某些端点,或者给端点加上某种标记,使其绕过 CORS 检查,从而允许并非来自 frontend.domain.com 的请求?

谢谢。

解决方法

为了解决这个问题,我修改了原来的 CORSMiddleware,新增了 allow_path_regex 选项:请求路径完全匹配该正则表达式的端点会接受任意来源(Origin)的请求。

用法:

    # Endpoints whose path fully matches allow_path_regex accept any origin;
    # every other endpoint is still restricted to allow_origins.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["allowed origins"],
        allow_credentials=True,
        allow_path_regex="(.*)\\/endpoint\\/[a-f0-9]{32}$",
        allow_methods=["*"],
        allow_headers=["*"],
    )

修改的中间件:

import functools
import re
import typing

from starlette.datastructures import Headers,MutableHeaders
from starlette.responses import PlainTextResponse,Response
from starlette.types import ASGIApp,Message,Receive,Scope,Send

# Methods substituted for allow_methods when it contains "*".
ALL_METHODS = (
    "DELETE",
    "GET",
    "OPTIONS",
    "PATCH",
    "POST",
    "PUT",
)

# Header names that are always merged into the allowed-headers list.
SAFELISTED_HEADERS = {
    "Accept",
    "Accept-Language",
    "Content-Language",
    "Content-Type",
}


class CORSMiddleware:
    """
    CORS middleware with an additional ``allow_path_regex`` option.

    Origins are checked against ``allow_origins`` / ``allow_origin_regex`` as
    usual, but a request whose path fully matches ``allow_path_regex`` is
    treated as coming from an allowed origin, which opens those endpoints to
    the world.

    NOTE: when ``allow_credentials`` is true, the
    ``Access-Control-Allow-Credentials`` header is also sent for endpoints
    matched by ``allow_path_regex``, so keep that pattern as narrow as
    possible.
    """

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: typing.Sequence[str] = (),
        allow_methods: typing.Sequence[str] = ("GET",),
        allow_headers: typing.Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: typing.Optional[str] = None,
        allow_path_regex: typing.Optional[str] = None,
        expose_headers: typing.Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """
        Precompute the static CORS headers and store the policy.

        Args:
            app: The wrapped ASGI application.
            allow_origins: Exact origins that are allowed; ``"*"`` allows all.
            allow_methods: Allowed methods; ``"*"`` expands to ALL_METHODS.
            allow_headers: Allowed request headers (merged with
                SAFELISTED_HEADERS); ``"*"`` allows any requested header.
            allow_credentials: Whether to send
                ``Access-Control-Allow-Credentials: true``.
            allow_origin_regex: Pattern an origin must fully match to be
                allowed.
            allow_path_regex: Pattern a request path must fully match for the
                request to be allowed regardless of its origin.
            expose_headers: Values for ``Access-Control-Expose-Headers``.
            max_age: Value for ``Access-Control-Max-Age`` on preflights.
        """
        if "*" in allow_methods:
            allow_methods = ALL_METHODS

        compiled_allow_origin_regex = None
        if allow_origin_regex is not None:
            compiled_allow_origin_regex = re.compile(allow_origin_regex)

        # An empty pattern is treated the same as "no path regex".
        compiled_allow_path_regex = None
        if allow_path_regex:
            compiled_allow_path_regex = re.compile(allow_path_regex)

        # Headers added to every non-preflight response.
        simple_headers = {}
        if "*" in allow_origins:
            simple_headers["Access-Control-Allow-Origin"] = "*"
        if allow_credentials:
            simple_headers["Access-Control-Allow-Credentials"] = "true"
        if expose_headers:
            simple_headers["Access-Control-Expose-Headers"] = ",".join(expose_headers)

        # Headers that form the base of every preflight response.
        preflight_headers = {}
        if "*" in allow_origins:
            preflight_headers["Access-Control-Allow-Origin"] = "*"
        else:
            preflight_headers["Vary"] = "Origin"
        preflight_headers.update(
            {
                "Access-Control-Allow-Methods": ",".join(allow_methods),
                "Access-Control-Max-Age": str(max_age),
            }
        )
        allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
        if allow_headers and "*" not in allow_headers:
            preflight_headers["Access-Control-Allow-Headers"] = ",".join(allow_headers)
        if allow_credentials:
            preflight_headers["Access-Control-Allow-Credentials"] = "true"

        self.app = app
        self.allow_origins = allow_origins
        self.allow_methods = allow_methods
        # Lower-cased so requested headers can be compared case-insensitively.
        self.allow_headers = [h.lower() for h in allow_headers]
        self.allow_all_origins = "*" in allow_origins
        self.allow_all_headers = "*" in allow_headers
        self.allow_origin_regex = compiled_allow_origin_regex
        self.allow_path_regex = compiled_allow_path_regex
        self.simple_headers = simple_headers
        self.preflight_headers = preflight_headers

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        ASGI entry point.

        Non-HTTP scopes and requests without an ``Origin`` header are passed
        straight through. Preflight requests are answered directly; all other
        requests are forwarded with CORS headers added to the response.
        """
        if scope["type"] != "http":  # pragma: no cover
            await self.app(scope, receive, send)
            return

        method = scope["method"]
        headers = Headers(scope=scope)
        origin = headers.get("origin")
        path = scope["path"]

        if origin is None:
            # Not a cross-origin request: nothing to add.
            await self.app(scope, receive, send)
            return

        if method == "OPTIONS" and "access-control-request-method" in headers:
            response = self.preflight_response(request_headers=headers, path=path)
            await response(scope, receive, send)
            return

        await self.simple_response(
            scope, receive, send, request_headers=headers, path=path
        )

    def is_allowed_origin(self, origin: str, path: str) -> bool:
        """
        Return True if a request from ``origin`` to ``path`` passes the policy.

        A request is allowed when all origins are allowed, when the origin
        fully matches ``allow_origin_regex``, when the path fully matches
        ``allow_path_regex``, or when the origin is listed in
        ``allow_origins``.
        """
        if self.allow_all_origins:
            return True

        if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
            origin
        ):
            return True

        if self.allow_path_regex is not None and self.allow_path_regex.fullmatch(path):
            return True

        return origin in self.allow_origins

    def preflight_response(self, request_headers: Headers, path: str) -> Response:
        """
        Build the response to a CORS preflight request.

        Args:
            request_headers: Headers of the preflight request; must contain
                ``origin`` and ``access-control-request-method``.
            path: Request path, checked against ``allow_path_regex``.

        Returns:
            A 200 "OK" response when origin, method and headers are all
            allowed, otherwise a 400 response naming what was disallowed.
        """
        requested_origin = request_headers["origin"]
        requested_method = request_headers["access-control-request-method"]
        requested_headers = request_headers.get("access-control-request-headers")

        headers = dict(self.preflight_headers)
        failures = []

        if self.is_allowed_origin(origin=requested_origin, path=path):
            if not self.allow_all_origins:
                # If self.allow_all_origins is True, then the
                # "Access-Control-Allow-Origin" header is already set to "*".
                # If we only allow specific origins, then we have to mirror
                # back the Origin header in the response.
                headers["Access-Control-Allow-Origin"] = requested_origin
        else:
            failures.append("origin")

        if requested_method not in self.allow_methods:
            failures.append("method")

        # If we allow all headers, then we have to mirror back any requested
        # headers in the response.
        if self.allow_all_headers and requested_headers is not None:
            headers["Access-Control-Allow-Headers"] = requested_headers
        elif requested_headers is not None:
            for header in requested_headers.split(","):
                if header.strip().lower() not in self.allow_headers:
                    failures.append("headers")
                    # Report the failure once, however many headers are rejected.
                    break

        # We don't strictly need to use 400 responses here, since it's up to
        # the browser to enforce the CORS policy, but it's more informative
        # if we do.
        if failures:
            failure_text = "Disallowed CORS " + ",".join(failures)
            return PlainTextResponse(failure_text, status_code=400, headers=headers)

        return PlainTextResponse("OK", status_code=200, headers=headers)

    async def simple_response(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request_headers: Headers,
        path: str,
    ) -> None:
        """
        Forward a non-preflight request to the wrapped app.

        The app's ``send`` callable is replaced by ``self.send`` bound to the
        original ``send``, the request headers and the path, so CORS headers
        can be added when the response starts.
        """
        send = functools.partial(
            self.send, send=send, request_headers=request_headers, path=path
        )
        await self.app(scope, receive, send)

    async def send(
        self, message: Message, send: Send, request_headers: Headers, path: str
    ) -> None:
        """
        Add CORS headers to the response start message, then forward it.

        Args:
            message: The ASGI message emitted by the wrapped app.
            send: The original ASGI ``send`` callable.
            request_headers: Headers of the incoming request.
            path: Request path, checked against ``allow_path_regex``.
        """
        if message["type"] != "http.response.start":
            # Only the start message carries headers; pass the rest untouched.
            await send(message)
            return

        message.setdefault("headers", [])
        headers = MutableHeaders(scope=message)
        headers.update(self.simple_headers)
        origin = request_headers["Origin"]
        has_cookie = "cookie" in request_headers

        # If request includes any cookie headers, then we must respond
        # with the specific origin instead of '*'.
        if self.allow_all_origins and has_cookie:
            headers["Access-Control-Allow-Origin"] = origin

        # If we only allow specific origins, then we have to mirror back
        # the Origin header in the response.
        elif not self.allow_all_origins and self.is_allowed_origin(
            origin=origin, path=path
        ):
            headers["Access-Control-Allow-Origin"] = origin
            headers.add_vary_header("Origin")
        await send(message)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...