go-zero/zrpc/internal/chainclientinterceptors.go
2020-09-18 11:41:52 +08:00

84 lines
2.7 KiB
Go

package internal
import (
"context"
"google.golang.org/grpc"
)
// WithStreamClientInterceptors returns a grpc.DialOption that installs all of
// the given stream interceptors as a single chained stream interceptor. The
// first interceptor in the list is the outermost one.
func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
	chained := chainStreamClientInterceptors(interceptors...)
	return grpc.WithStreamInterceptor(chained)
}
// WithUnaryClientInterceptors returns a grpc.DialOption that installs all of
// the given unary interceptors as a single chained unary interceptor. The
// first interceptor in the list is the outermost one.
func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
	chained := chainUnaryClientInterceptors(interceptors...)
	return grpc.WithUnaryInterceptor(chained)
}
// chainStreamClientInterceptors combines interceptors into one
// grpc.StreamClientInterceptor.
//
// interceptors[0] runs first. The streamer it receives invokes
// interceptors[1], and so on. The last interceptor receives the real streamer
// supplied by grpc.
//
// With no interceptors, the result just calls the streamer. With exactly one,
// that interceptor is returned unchanged.
func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
	switch len(interceptors) {
	case 0:
		return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
			streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
			return streamer(ctx, desc, cc, method, opts...)
		}
	case 1:
		return interceptors[0]
	default:
		return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
			streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
			return interceptors[0](ctx, desc, cc, method, buildChainStreamer(interceptors, 0, streamer), opts...)
		}
	}
}

// buildChainStreamer returns the streamer that is handed to interceptors[curr].
// That streamer calls interceptors[curr+1], or finalStreamer when curr is the
// last index.
//
// Each level captures its own immutable position instead of sharing a mutable
// counter. An interceptor may therefore call its streamer several times,
// concurrently from multiple goroutines, or again after recovering from a
// downstream panic, and the chain still dispatches correctly.
func buildChainStreamer(interceptors []grpc.StreamClientInterceptor, curr int,
	finalStreamer grpc.Streamer) grpc.Streamer {
	if curr == len(interceptors)-1 {
		return finalStreamer
	}

	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
		opts ...grpc.CallOption) (grpc.ClientStream, error) {
		next := buildChainStreamer(interceptors, curr+1, finalStreamer)
		return interceptors[curr+1](ctx, desc, cc, method, next, opts...)
	}
}
// chainUnaryClientInterceptors combines interceptors into one
// grpc.UnaryClientInterceptor.
//
// interceptors[0] runs first. The invoker it receives invokes interceptors[1],
// and so on. The last interceptor receives the real invoker supplied by grpc.
//
// With no interceptors, the result just calls the invoker. With exactly one,
// that interceptor is returned unchanged.
func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
	switch len(interceptors) {
	case 0:
		return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
			invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
			return invoker(ctx, method, req, reply, cc, opts...)
		}
	case 1:
		return interceptors[0]
	default:
		return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
			invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
			return interceptors[0](ctx, method, req, reply, cc,
				buildChainUnaryInvoker(interceptors, 0, invoker), opts...)
		}
	}
}

// buildChainUnaryInvoker returns the invoker that is handed to
// interceptors[curr]. That invoker calls interceptors[curr+1], or finalInvoker
// when curr is the last index.
//
// Each level captures its own immutable position instead of sharing a mutable
// counter. An interceptor may therefore call its invoker several times (for
// example, retries or hedged attempts), concurrently from multiple goroutines,
// or again after recovering from a downstream panic, and the chain still
// dispatches correctly.
func buildChainUnaryInvoker(interceptors []grpc.UnaryClientInterceptor, curr int,
	finalInvoker grpc.UnaryInvoker) grpc.UnaryInvoker {
	if curr == len(interceptors)-1 {
		return finalInvoker
	}

	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
		opts ...grpc.CallOption) error {
		next := buildChainUnaryInvoker(interceptors, curr+1, finalInvoker)
		return interceptors[curr+1](ctx, method, req, reply, cc, next, opts...)
	}
}