diff --git a/core/codec/aesecb.go b/core/codec/aesecb.go index 08a388fb..1ec99a04 100644 --- a/core/codec/aesecb.go +++ b/core/codec/aesecb.go @@ -32,9 +32,11 @@ func NewECBEncrypter(b cipher.Block) cipher.BlockMode { return (*ecbEncrypter)(newECB(b)) } +// BlockSize returns the mode's block size. func (x *ecbEncrypter) BlockSize() int { return x.blockSize } -// why we don't return error is because cipher.BlockMode doesn't allow this +// CryptBlocks encrypts a number of blocks. The length of src must be a multiple of +// the block size. Dst and src must overlap entirely or not at all. func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { if len(src)%x.blockSize != 0 { logx.Error("crypto/cipher: input not full blocks") @@ -59,11 +61,13 @@ func NewECBDecrypter(b cipher.Block) cipher.BlockMode { return (*ecbDecrypter)(newECB(b)) } +// BlockSize returns the mode's block size. func (x *ecbDecrypter) BlockSize() int { return x.blockSize } -// why we don't return error is because cipher.BlockMode doesn't allow this +// CryptBlocks decrypts a number of blocks. The length of src must be a multiple of +// the block size. Dst and src must overlap entirely or not at all. func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { if len(src)%x.blockSize != 0 { logx.Error("crypto/cipher: input not full blocks") diff --git a/core/codec/aesecb_test.go b/core/codec/aesecb_test.go index 46a93ad7..a1117f3a 100644 --- a/core/codec/aesecb_test.go +++ b/core/codec/aesecb_test.go @@ -1,6 +1,7 @@ package codec import ( + "crypto/aes" "encoding/base64" "testing" @@ -10,7 +11,8 @@ import ( func TestAesEcb(t *testing.T) { var ( key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") - val = []byte("hello") + val = []byte("helloworld") + valLong = []byte("helloworldlong..") badKey1 = []byte("aaaaaaaaa") // more than 32 chars badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") @@ -31,6 +33,39 @@ func TestAesEcb(t *testing.T) { src, err := EcbDecrypt(key, dst) assert.Nil(t, err) assert.Equal(t, val, src) + block, err := aes.NewCipher(key) + assert.NoError(t, err) + encrypter := NewECBEncrypter(block) + assert.Equal(t, 16, encrypter.BlockSize()) + decrypter := NewECBDecrypter(block) + assert.Equal(t, 16, decrypter.BlockSize()) + + dst = make([]byte, 8) + encrypter.CryptBlocks(dst, val) + for _, b := range dst { + assert.Equal(t, byte(0), b) + } + + dst = make([]byte, 8) + encrypter.CryptBlocks(dst, valLong) + for _, b := range dst { + assert.Equal(t, byte(0), b) + } + + dst = make([]byte, 8) + decrypter.CryptBlocks(dst, val) + for _, b := range dst { + assert.Equal(t, byte(0), b) + } + + dst = make([]byte, 8) + decrypter.CryptBlocks(dst, valLong) + for _, b := range dst { + assert.Equal(t, byte(0), b) + } + + _, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=") + assert.Error(t, err) } func TestAesEcbBase64(t *testing.T) { diff --git a/core/codec/dh_test.go b/core/codec/dh_test.go index 951ff486..9f788b44 100644 --- a/core/codec/dh_test.go +++ b/core/codec/dh_test.go @@ -80,3 +80,17 @@ func TestKeyBytes(t *testing.T) { assert.Nil(t, err) assert.True(t, len(key.Bytes()) > 0) } + +func TestDHOnErrors(t *testing.T) { + key, err := GenerateKey() + assert.Nil(t, err) + assert.NotEmpty(t, key.Bytes()) + _, err = ComputeKey(key.PubKey, key.PriKey) + assert.NoError(t, err) + _, err = ComputeKey(nil, key.PriKey) + assert.Error(t, err) + _, err = ComputeKey(key.PubKey, nil) + assert.Error(t, err) + + assert.NotNil(t, NewPublicKey([]byte(""))) +}