go-zero/zrpc/internal/serverinterceptors/authinterceptor_test.go

198 lines
4.2 KiB
Go
Raw Normal View History

2020-08-22 23:08:33 +08:00
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
"github.com/zeromicro/go-zero/zrpc/internal/auth"
2020-08-23 15:53:10 +08:00
"google.golang.org/grpc"
2020-08-22 23:08:33 +08:00
"google.golang.org/grpc/metadata"
)
2020-08-23 15:53:10 +08:00
func TestStreamAuthorizeInterceptor(t *testing.T) {
tests := []struct {
name string
app string
token string
strict bool
hasError bool
}{
{
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
},
{
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
},
}
2020-11-02 17:51:33 +08:00
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
2020-08-23 15:53:10 +08:00
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := StreamAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
"app": "foo",
"token": "bar",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
stream := mockedStream{ctx: ctx}
err = interceptor(nil, stream, nil, func(_ interface{}, _ grpc.ServerStream) error {
2020-08-23 15:53:10 +08:00
return nil
})
if test.hasError {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}
2020-08-22 23:08:33 +08:00
func TestUnaryAuthorizeInterceptor(t *testing.T) {
tests := []struct {
2020-08-23 15:53:10 +08:00
name string
app string
token string
strict bool
hasError bool
2020-08-22 23:08:33 +08:00
}{
{
2020-08-23 15:53:10 +08:00
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
2020-08-22 23:08:33 +08:00
},
{
2020-08-23 15:53:10 +08:00
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
2020-08-22 23:08:33 +08:00
},
}
2020-11-02 17:51:33 +08:00
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
2020-08-22 23:08:33 +08:00
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
2020-08-23 15:53:10 +08:00
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
2020-08-22 23:08:33 +08:00
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := UnaryAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
2020-08-23 15:53:10 +08:00
"app": "foo",
"token": "bar",
2020-08-22 23:08:33 +08:00
})
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
2020-08-23 15:53:10 +08:00
if test.hasError {
2020-08-22 23:08:33 +08:00
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
2020-08-23 15:53:10 +08:00
if test.strict {
_, err = interceptor(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
var md metadata.MD
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
md = metadata.New(map[string]string{
"app": "",
"token": "",
})
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
}
2020-08-22 23:08:33 +08:00
})
}
}
2020-08-23 15:53:10 +08:00
// mockedStream is a minimal stand-in for a grpc.ServerStream used by the
// stream interceptor tests. It carries only a context; every other method is
// a no-op that reports success.
type mockedStream struct {
	ctx context.Context
}

// Context returns the context the stream was constructed with.
func (s mockedStream) Context() context.Context {
	return s.ctx
}

// SetHeader ignores the given metadata and reports success.
func (mockedStream) SetHeader(metadata.MD) error { return nil }

// SendHeader ignores the given metadata and reports success.
func (mockedStream) SendHeader(metadata.MD) error { return nil }

// SetTrailer ignores the given metadata.
func (mockedStream) SetTrailer(metadata.MD) {}

// SendMsg drops the outgoing message and reports success.
func (mockedStream) SendMsg(interface{}) error { return nil }

// RecvMsg leaves the destination untouched and reports success.
func (mockedStream) RecvMsg(interface{}) error { return nil }