refactor: simplify BatchError (#4292)

This commit is contained in:
Kevin Wan 2024-08-03 13:57:41 +08:00 committed by GitHub
parent c6348b9855
commit dedba17219
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 69 additions and 38 deletions

View File

@ -52,7 +52,7 @@ jobs:
- name: Set up Go 1.x
uses: actions/setup-go@v5
with:
# use 1.20 to guarantee Go 1.20 compatibility
# make sure Go version compatible with go-zero
go-version-file: go.mod
check-latest: true
cache: true

View File

@ -5,17 +5,13 @@ import (
"sync"
)
type (
// A BatchError is an error that can hold multiple errors.
BatchError struct {
errs errorArray
lock sync.Mutex
}
// BatchError is an error that can hold multiple errors.
type BatchError struct {
errs []error
lock sync.RWMutex
}
errorArray []error
)
// Add adds errs to be, nil errors are ignored.
// Add adds one or more non-nil errors to the BatchError instance.
func (be *BatchError) Add(errs ...error) {
be.lock.Lock()
defer be.lock.Unlock()
@ -27,35 +23,20 @@ func (be *BatchError) Add(errs ...error) {
}
}
// Err returns an error that represents all errors.
// Err returns an error that represents all accumulated errors.
// It returns nil if there are no errors.
func (be *BatchError) Err() error {
be.lock.Lock()
defer be.lock.Unlock()
be.lock.RLock()
defer be.lock.RUnlock()
switch len(be.errs) {
case 0:
return nil
case 1:
return be.errs[0]
default:
return be.errs
}
// If there are no non-nil errors, errors.Join(...) returns nil.
return errors.Join(be.errs...)
}
// NotNil checks if any error inside.
// NotNil checks if there is at least one error inside the BatchError.
func (be *BatchError) NotNil() bool {
be.lock.Lock()
defer be.lock.Unlock()
be.lock.RLock()
defer be.lock.RUnlock()
return len(be.errs) > 0
}
// Error returns a string that represents inside errors.
func (ea errorArray) Error() string {
return errors.Join(ea...).Error()
}
// Unwrap combine the errors in the errorArray into a single error return
func (ea errorArray) Unwrap() error {
return errors.Join(ea...)
}

View File

@ -95,3 +95,53 @@ func TestBatchError_Unwrap(t *testing.T) {
assert.False(t, errors.Is(be.Err(), errBaz))
})
}
func TestBatchError_Add(t *testing.T) {
var be BatchError
// Test adding nil errors
be.Add(nil, nil)
assert.False(t, be.NotNil(), "Expected BatchError to be empty after adding nil errors")
// Test adding non-nil errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding errors")
// Test adding a mix of nil and non-nil errors
err3 := errors.New("error 3")
be.Add(nil, err3, nil)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding a mix of nil and non-nil errors")
}
func TestBatchError_Err(t *testing.T) {
var be BatchError
// Test Err() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test Err() with multiple errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
combinedErr := be.Err()
assert.NotNil(t, combinedErr, "Expected nil error for BatchError with multiple errors")
// Check if the combined error contains both error messages
errString := combinedErr.Error()
assert.Truef(t, errors.Is(combinedErr, err1), "Combined error doesn't contain first error: %s", errString)
assert.Truef(t, errors.Is(combinedErr, err2), "Combined error doesn't contain second error: %s", errString)
}
func TestBatchError_NotNil(t *testing.T) {
var be BatchError
// Test NotNil() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test NotNil() after adding an error
be.Add(errors.New("test error"))
assert.NotNil(t, be.Err(), "Expected non-nil error after adding an error")
}

View File

@ -94,4 +94,4 @@ else
echo "a diff found"
execute_command "diff -r $OLD_CODE $NEW_CODE"
exit 1
fi
fi

View File

@ -62,4 +62,4 @@ service Test_Service {
rpc ClientStream (stream Req) returns (Reply);
// stream
rpc Stream(stream Req) returns (stream Reply);
}
}

View File

@ -62,4 +62,4 @@ service Test_Service {
rpc ClientStream (stream Req) returns (Reply);
// stream
rpc Stream(stream Req) returns (stream Reply);
}
}