mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 00:50:20 +08:00
chore: refactor zrpc timeout (#3671)
This commit is contained in:
parent
842c4d81cc
commit
922efbfc2d
@ -111,7 +111,7 @@ func SetClientSlowThreshold(threshold time.Duration) {
|
||||
clientinterceptors.SetSlowThreshold(threshold)
|
||||
}
|
||||
|
||||
// WithTimeoutCallOption return a call option with given timeout.
|
||||
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
|
||||
return clientinterceptors.WithTimeoutCallOption(timeout)
|
||||
// WithCallTimeout return a call option with given timeout to make a method call.
|
||||
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
|
||||
return clientinterceptors.WithCallTimeout(timeout)
|
||||
}
|
||||
|
@ -41,12 +41,12 @@ func dialer() func(context.Context, string) (net.Conn, error) {
|
||||
|
||||
func TestDepositServer_Deposit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
amount float32
|
||||
timeoutCallOption time.Duration
|
||||
res *mock.DepositResponse
|
||||
errCode codes.Code
|
||||
errMsg string
|
||||
name string
|
||||
amount float32
|
||||
timeout time.Duration
|
||||
res *mock.DepositResponse
|
||||
errCode codes.Code
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "invalid request with negative amount",
|
||||
@ -66,12 +66,12 @@ func TestDepositServer_Deposit(t *testing.T) {
|
||||
errMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "valid request with timeout call option",
|
||||
amount: 2000.00,
|
||||
timeoutCallOption: time.Second * 3,
|
||||
res: &mock.DepositResponse{Ok: true},
|
||||
errCode: codes.OK,
|
||||
errMsg: "",
|
||||
name: "valid request with timeout call option",
|
||||
amount: 2000.00,
|
||||
timeout: time.Second * 3,
|
||||
res: &mock.DepositResponse{Ok: true},
|
||||
errCode: codes.OK,
|
||||
errMsg: "",
|
||||
},
|
||||
}
|
||||
|
||||
@ -171,8 +171,8 @@ func TestDepositServer_Deposit(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
if tt.timeoutCallOption > 0 {
|
||||
response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption))
|
||||
if tt.timeout > 0 {
|
||||
response, err = cli.Deposit(ctx, request, WithCallTimeout(tt.timeout))
|
||||
} else {
|
||||
response, err = cli.Deposit(ctx, request)
|
||||
}
|
||||
|
@ -17,8 +17,8 @@ type (
|
||||
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
|
||||
// StatConf defines the stat config.
|
||||
StatConf = internal.StatConf
|
||||
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
|
||||
ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf
|
||||
// MethodTimeoutConf defines specified timeout for gRPC method.
|
||||
MethodTimeoutConf = internal.MethodTimeoutConf
|
||||
|
||||
// A RpcClientConf is a rpc client config.
|
||||
RpcClientConf struct {
|
||||
@ -48,7 +48,7 @@ type (
|
||||
Health bool `json:",default=true"`
|
||||
Middlewares ServerMiddlewaresConf
|
||||
// setting specified timeout for gRPC method
|
||||
SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"`
|
||||
MethodTimeouts []MethodTimeoutConf `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -7,11 +7,17 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// TimeoutCallOption is a call option that controls timeout.
|
||||
type TimeoutCallOption struct {
|
||||
grpc.EmptyCallOption
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// TimeoutInterceptor is an interceptor that controls timeout.
|
||||
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
t := getTimeoutByCallOptions(opts, timeout)
|
||||
t := getTimeoutFromCallOptions(opts, timeout)
|
||||
if t <= 0 {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
||||
}
|
||||
}
|
||||
|
||||
func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
|
||||
for _, callOption := range callOptions {
|
||||
if o, ok := callOption.(TimeoutCallOption); ok {
|
||||
// WithCallTimeout returns a call option that controls method call timeout.
|
||||
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
|
||||
return TimeoutCallOption{
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
|
||||
for _, opt := range opts {
|
||||
if o, ok := opt.(TimeoutCallOption); ok {
|
||||
return o.timeout
|
||||
}
|
||||
}
|
||||
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
type TimeoutCallOption struct {
|
||||
grpc.EmptyCallOption
|
||||
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
|
||||
return TimeoutCallOption{
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
var co []grpc.CallOption
|
||||
if tt.args.callOptionTimeout > 0 {
|
||||
co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
|
||||
co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
|
||||
}
|
||||
|
||||
err := interceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
|
@ -25,5 +25,6 @@ type (
|
||||
Breaker bool `json:",default=true"`
|
||||
}
|
||||
|
||||
ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
|
||||
// MethodTimeoutConf defines specified timeout for gRPC methods.
|
||||
MethodTimeoutConf = serverinterceptors.MethodTimeoutConf
|
||||
)
|
||||
|
@ -15,21 +15,22 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
|
||||
ServerSpecifiedTimeoutConf struct {
|
||||
// MethodTimeoutConf defines specified timeout for gRPC method.
|
||||
MethodTimeoutConf struct {
|
||||
FullMethod string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
specifiedTimeoutCache map[string]time.Duration
|
||||
methodTimeouts map[string]time.Duration
|
||||
)
|
||||
|
||||
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
|
||||
cache := cacheSpecifiedTimeout(specifiedTimeouts)
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration,
|
||||
methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
|
||||
timeouts := buildMethodTimeouts(methodTimeouts)
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (any, error) {
|
||||
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
|
||||
t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, t)
|
||||
defer cancel()
|
||||
|
||||
@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
|
||||
}
|
||||
}
|
||||
|
||||
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
|
||||
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
|
||||
for _, st := range specifiedTimeouts {
|
||||
func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
|
||||
mt := make(methodTimeouts, len(timeouts))
|
||||
for _, st := range timeouts {
|
||||
if st.FullMethod != "" {
|
||||
cache[st.FullMethod] = st.Timeout
|
||||
mt[st.FullMethod] = st.Timeout
|
||||
}
|
||||
}
|
||||
|
||||
return cache
|
||||
return mt
|
||||
}
|
||||
|
||||
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
|
||||
if ts, ok := info.Server.(TimeoutStrategy); ok {
|
||||
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
|
||||
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
|
||||
func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
|
||||
defaultTimeout time.Duration) time.Duration {
|
||||
if v, ok := timeouts[method]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
type TimeoutStrategy interface {
|
||||
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
|
||||
}
|
||||
|
@ -103,13 +103,6 @@ type tempServer struct {
|
||||
func (s *tempServer) run(duration time.Duration) {
|
||||
time.Sleep(duration)
|
||||
}
|
||||
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
|
||||
if fullMethod == "/" {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
return s.timeout
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
||||
type args struct {
|
||||
@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "do not timeout with timeout strategy",
|
||||
args: args{
|
||||
interceptorTimeout: time.Second,
|
||||
contextTimeout: time.Second * 5,
|
||||
serverTimeout: time.Second * 3,
|
||||
runTime: time.Second * 2,
|
||||
fullMethod: "/2s",
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "timeout with interceptor timeout",
|
||||
args: args{
|
||||
@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var specifiedTimeouts []ServerSpecifiedTimeoutConf
|
||||
var specifiedTimeouts []MethodTimeoutConf
|
||||
if tt.args.methodTimeout > 0 {
|
||||
specifiedTimeouts = []ServerSpecifiedTimeoutConf{
|
||||
specifiedTimeouts = []MethodTimeoutConf{
|
||||
{
|
||||
FullMethod: tt.args.method,
|
||||
Timeout: tt.args.methodTimeout,
|
||||
|
@ -131,12 +131,8 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
|
||||
}
|
||||
|
||||
if c.Timeout > 0 {
|
||||
svr.AddUnaryInterceptors(
|
||||
serverinterceptors.UnaryTimeoutInterceptor(
|
||||
time.Duration(c.Timeout)*time.Millisecond,
|
||||
c.SpecifiedTimeouts...,
|
||||
),
|
||||
)
|
||||
svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
|
||||
time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...))
|
||||
}
|
||||
|
||||
if c.Auth {
|
||||
|
@ -40,7 +40,7 @@ func TestServer_setupInterceptors(t *testing.T) {
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
|
||||
MethodTimeouts: []MethodTimeoutConf{
|
||||
{
|
||||
FullMethod: "/foo",
|
||||
Timeout: 5 * time.Second,
|
||||
@ -81,7 +81,7 @@ func TestServer(t *testing.T) {
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
|
||||
MethodTimeouts: []MethodTimeoutConf{
|
||||
{
|
||||
FullMethod: "/foo",
|
||||
Timeout: time.Second,
|
||||
@ -117,7 +117,7 @@ func TestServerError(t *testing.T) {
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
|
||||
MethodTimeouts: []MethodTimeoutConf{},
|
||||
}, func(server *grpc.Server) {
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
@ -144,7 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
|
||||
Prometheus: true,
|
||||
Breaker: true,
|
||||
},
|
||||
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
|
||||
MethodTimeouts: []MethodTimeoutConf{},
|
||||
}, func(server *grpc.Server) {
|
||||
})
|
||||
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
|
||||
|
Loading…
Reference in New Issue
Block a user