如何在 gRPC 服务器中间件（拦截器）中检查截止时间（deadline）

问题描述

根据官方文档，gRPC 服务器可以这样检查截止时间（deadline）：

// Abandon the work once the client has cancelled the call.
// Only context.Canceled is checked here; an expired deadline
// (context.DeadlineExceeded) is not handled by this condition.
// NOTE(review): status.New yields a *status.Status; in a handler whose
// result type is error it presumably needs .Err() — confirm.
if ctx.Err() == context.Canceled {
    return status.New(codes.Canceled,"Client cancelled,abandoning.")
}

我尝试在 helloworld 示例中这样处理它：

// SayHello implements helloworld.GreeterServer
func (s *server) SayHello(ctx context.Context,in *pb.HelloRequest) (*pb.HelloReply,error) {
    log.Printf("Received: %v",in.GetName())

    select {
    case <- time.After(time.Second):
        // simulate some operation here
    case <- ctx.Done():
        if ctx.Err() == context.Canceled || ctx.Err() == context.DeadlineExceeded{
            return nil,status.New(codes.Canceled,abandoning.").Err()
        }
    }

    return &pb.HelloReply{Message: "Hello " + in.GetName()},nil
}

并使用如下测试客户端代码：

    // Give the call a 600 ms budget. The server's simulated work waits one
    // second, so the deadline should expire before it finishes.
    // cancel releases the context's resources once the call returns.
    ctx,cancel := context.WithTimeout(context.Background(),time.Millisecond*600)
    defer cancel()
    r,err := c.SayHello(ctx,&pb.HelloRequest{Name: name})

我的问题：如何在 gRPC 服务器的中间件（或拦截器）中检查截止时间？

我尝试过的方法

// serverStats logs gRPC RPC stats events and reports when the request
// context has been cancelled or has passed its deadline.
// NOTE(review): a stats.Handler also needs TagConn/HandleConn; presumably
// they live in the elided part of this snippet — confirm.
type serverStats struct{}

// TagRPC prints a trace line and hands back ctx unchanged.
func (h *serverStats) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
	fmt.Println("tag rpc")
	return ctx
}

// HandleRPC prints a trace line for each stats event it receives (it can run
// several times per call). If ctx is already cancelled or past its deadline,
// it also prints the context error.
func (h *serverStats) HandleRPC(ctx context.Context, s stats.RPCStats) {
	fmt.Println("handle rpc")
	switch cause := ctx.Err(); cause {
	case context.Canceled, context.DeadlineExceeded:
		fmt.Printf("HandleRPC: Client err %+v \n", cause)
	}
}

...

    // Register serverStats as the server's stats handler so it receives RPC events.
    s := grpc.NewServer(grpc.StatsHandler(&serverStats{}))

然而，HandleRPC 在一次调用中会被多次触发，而且它无法返回错误状态码。

解决方法

我以 hello 示例为基础编写了一个超时拦截器示例。

可能还有更好的做法，但我的实现如下：

/*
 *
 * Copyright 2015 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package main implements a server for Greeter service.
package main

import (
    "context"
    "log"
    "net"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    pb "google.golang.org/grpc/examples/helloworld/helloworld"
    "google.golang.org/grpc/status"
)

// port is the listen address handed to net.Listen for the Greeter server.
const port = ":50051"

// timeoutInterceptor is a unary server interceptor that runs the handler in
// its own goroutine. It returns as soon as either the handler finishes or the
// request context is done, whichever happens first.
//
// When ctx is done first, it returns a DeadlineExceeded or Canceled status
// (matching ctx.Err()) without waiting for the handler. The handler goroutine
// keeps running until the handler itself returns, so handlers should still
// watch ctx.Done() to stop promptly.
func timeoutInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	// outcome carries the handler's results across the channel, so no
	// variables are shared between the two goroutines.
	type outcome struct {
		resp interface{}
		err  error
	}

	// Buffered so the goroutine can always deliver its outcome and exit, even
	// after the interceptor has stopped waiting. An unbuffered send would block
	// forever and leak the goroutine.
	done := make(chan outcome, 1)
	go func() {
		resp, err := handler(ctx, req)
		done <- outcome{resp: resp, err: err}
	}()

	select {
	case out := <-done:
		return out.resp, out.err
	case <-ctx.Done():
		// Once Done is closed, ctx.Err() is DeadlineExceeded or Canceled.
		if ctx.Err() == context.DeadlineExceeded {
			return nil, status.New(codes.DeadlineExceeded, "Deadline exceeded, abandoning.").Err()
		}
		return nil, status.New(codes.Canceled, "Client cancelled,abandoning.").Err()
	}
}

// server implements the helloworld.GreeterServer interface. It embeds the
// generated pb.UnimplementedGreeterServer type.
type server struct{ pb.UnimplementedGreeterServer }

// SayHello logs the caller's name and replies with a greeting built from it.
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	name := in.GetName()
	log.Printf("Received: %v", name)
	reply := &pb.HelloReply{Message: "Hello " + name}
	return reply, nil
}

// main listens on port, installs timeoutInterceptor on a new gRPC server,
// registers the Greeter implementation and serves until Serve fails.
func main() {
	listener, err := net.Listen("tcp", port)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}

	grpcServer := grpc.NewServer(grpc.ChainUnaryInterceptor(timeoutInterceptor))
	pb.RegisterGreeterServer(grpcServer, &server{})

	if serveErr := grpcServer.Serve(listener); serveErr != nil {
		log.Fatalf("failed to serve: %v", serveErr)
	}
}
---

您可以使用这个库。

拦截器可以拿到请求的上下文（context）信息。

而且我认为您不需要在服务器端自己处理截止时间（尤其是一元请求）：状态码 4 DEADLINE_EXCEEDED 错误由 gRPC 库自行处理。您只需要确保客户端设置了截止时间，并且该截止时间在可接受的范围内。

func UnaryInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context,error) {
        //check whether client set the deadline
        d,ok := ctx.Deadline()
        if !ok {
            //return some error
        }
        timeout := d.Sub(time.Now())
        //check timeout range
        if timeout < 5*time.Second || timeout > 30*time.Second {
            //return some error
        }
}