mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
chore: optimize string search with Aho–Corasick algorithm (#1476)
* chore: optimize string search with Aho–Corasick algorithm
* chore: optimize keywords replacer
* fix: replacer bugs
* chore: reorder members
This commit is contained in:
parent
09d1fad6e0
commit
f1102fb262
@ -2,6 +2,8 @@ package stringx
|
||||
|
||||
// node is one vertex of the keyword trie used for multi-keyword matching.
type node struct {
	// children maps a rune to the vertex reached by consuming that rune.
	children map[rune]*node
	// fail is the vertex the search falls back to when no child matches.
	fail *node
	// depth is the number of runes on the path from the root to this vertex.
	depth int
	// end marks that the path from the root to this vertex spells a whole keyword.
	end bool
}
|
||||
|
||||
@ -12,17 +14,19 @@ func (n *node) add(word string) {
|
||||
}
|
||||
|
||||
nd := n
|
||||
for _, char := range chars {
|
||||
var depth int
|
||||
for i, char := range chars {
|
||||
if nd.children == nil {
|
||||
child := new(node)
|
||||
nd.children = map[rune]*node{
|
||||
char: child,
|
||||
}
|
||||
child.depth = i + 1
|
||||
nd.children = map[rune]*node{char: child}
|
||||
nd = child
|
||||
} else if child, ok := nd.children[char]; ok {
|
||||
nd = child
|
||||
depth++
|
||||
} else {
|
||||
child := new(node)
|
||||
child.depth = i + 1
|
||||
nd.children[char] = child
|
||||
nd = child
|
||||
}
|
||||
@ -30,3 +34,68 @@ func (n *node) add(word string) {
|
||||
|
||||
nd.end = true
|
||||
}
|
||||
|
||||
func (n *node) build() {
|
||||
n.fail = n
|
||||
for _, child := range n.children {
|
||||
child.fail = n
|
||||
n.buildNode(child)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) buildNode(nd *node) {
|
||||
if nd.children == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var fifo []*node
|
||||
for key, child := range nd.children {
|
||||
fifo = append(fifo, child)
|
||||
|
||||
if fail, ok := nd.fail.children[key]; ok {
|
||||
child.fail = fail
|
||||
} else {
|
||||
child.fail = n
|
||||
}
|
||||
}
|
||||
|
||||
for _, val := range fifo {
|
||||
n.buildNode(val)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) find(chars []rune) []scope {
|
||||
var scopes []scope
|
||||
size := len(chars)
|
||||
cur := n
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
child, ok := cur.children[chars[i]]
|
||||
if ok {
|
||||
cur = child
|
||||
} else if cur == n {
|
||||
continue
|
||||
} else {
|
||||
cur = cur.fail
|
||||
if child, ok = cur.children[chars[i]]; !ok {
|
||||
continue
|
||||
}
|
||||
cur = child
|
||||
}
|
||||
|
||||
if child.end {
|
||||
scopes = append(scopes, scope{
|
||||
start: i + 1 - child.depth,
|
||||
stop: i + 1,
|
||||
})
|
||||
}
|
||||
if child.fail != n && child.fail.end {
|
||||
scopes = append(scopes, scope{
|
||||
start: i + 1 - child.fail.depth,
|
||||
stop: i + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return scopes
|
||||
}
|
||||
|
25
core/stringx/node_test.go
Normal file
25
core/stringx/node_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package stringx
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkNodeFind(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
keywords := []string{
|
||||
"A",
|
||||
"AV",
|
||||
"AV演员",
|
||||
"无名氏",
|
||||
"AV演员色情",
|
||||
"日本AV女优",
|
||||
}
|
||||
trie := new(node)
|
||||
for _, keyword := range keywords {
|
||||
trie.add(keyword)
|
||||
}
|
||||
trie.build()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
trie.find([]rune("日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演"))
|
||||
}
|
||||
}
|
@ -9,7 +9,7 @@ type (
|
||||
}
|
||||
|
||||
replacer struct {
|
||||
node
|
||||
*node
|
||||
mapping map[string]string
|
||||
}
|
||||
)
|
||||
@ -17,58 +17,81 @@ type (
|
||||
// NewReplacer returns a Replacer.
|
||||
func NewReplacer(mapping map[string]string) Replacer {
|
||||
rep := &replacer{
|
||||
node: new(node),
|
||||
mapping: mapping,
|
||||
}
|
||||
for k := range mapping {
|
||||
rep.add(k)
|
||||
}
|
||||
rep.build()
|
||||
|
||||
return rep
|
||||
}
|
||||
|
||||
// Replace replaces text with given substitutes.
|
||||
func (r *replacer) Replace(text string) string {
|
||||
var builder strings.Builder
|
||||
var start int
|
||||
chars := []rune(text)
|
||||
size := len(chars)
|
||||
start := -1
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
child, ok := r.children[chars[i]]
|
||||
if !ok {
|
||||
builder.WriteRune(chars[i])
|
||||
continue
|
||||
for start < size {
|
||||
cur := r.node
|
||||
|
||||
if start > 0 {
|
||||
builder.WriteString(string(chars[:start]))
|
||||
}
|
||||
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
end := -1
|
||||
if child.end {
|
||||
end = i + 1
|
||||
}
|
||||
for i := start; i < size; i++ {
|
||||
child, ok := cur.children[chars[i]]
|
||||
if ok {
|
||||
cur = child
|
||||
} else if cur == r.node {
|
||||
builder.WriteRune(chars[i])
|
||||
// cur already points to root, set start only
|
||||
start = i + 1
|
||||
continue
|
||||
} else {
|
||||
curDepth := cur.depth
|
||||
cur = cur.fail
|
||||
child, ok = cur.children[chars[i]]
|
||||
if !ok {
|
||||
// write this path
|
||||
builder.WriteString(string(chars[i-curDepth : i+1]))
|
||||
// go to root
|
||||
cur = r.node
|
||||
start = i + 1
|
||||
continue
|
||||
}
|
||||
|
||||
j := i + 1
|
||||
for ; j < size; j++ {
|
||||
grandchild, ok := child.children[chars[j]]
|
||||
if !ok {
|
||||
failDepth := cur.depth
|
||||
// write path before jump
|
||||
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
|
||||
start += curDepth - failDepth
|
||||
cur = child
|
||||
}
|
||||
|
||||
if cur.end {
|
||||
val := string(chars[i+1-cur.depth : i+1])
|
||||
builder.WriteString(r.mapping[val])
|
||||
builder.WriteString(string(chars[i+1:]))
|
||||
// only matching this path, all previous paths are done
|
||||
if start >= i+1-cur.depth && i+1 >= size {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
chars = []rune(builder.String())
|
||||
size = len(chars)
|
||||
builder.Reset()
|
||||
break
|
||||
}
|
||||
|
||||
child = grandchild
|
||||
if child.end {
|
||||
end = j + 1
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
if end > 0 {
|
||||
i = j - 1
|
||||
builder.WriteString(r.mapping[string(chars[start:end])])
|
||||
} else {
|
||||
builder.WriteRune(chars[i])
|
||||
if !cur.end {
|
||||
builder.WriteString(string(chars[start:]))
|
||||
return builder.String()
|
||||
}
|
||||
start = -1
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
return string(chars)
|
||||
}
|
||||
|
42
core/stringx/replacer_fuzz_test.go
Normal file
42
core/stringx/replacer_fuzz_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
//go:build go1.18
|
||||
// +build go1.18
|
||||
|
||||
package stringx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func FuzzReplacerReplace(f *testing.F) {
|
||||
keywords := make(map[string]string)
|
||||
for i := 0; i < 20; i++ {
|
||||
keywords[Randn(rand.Intn(10)+5)] = Randn(rand.Intn(5) + 1)
|
||||
}
|
||||
rep := NewReplacer(keywords)
|
||||
printableKeywords := func() string {
|
||||
var buf strings.Builder
|
||||
for k, v := range keywords {
|
||||
fmt.Fprintf(&buf, "%q: %q,\n", k, v)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
f.Add(50)
|
||||
f.Fuzz(func(t *testing.T, n int) {
|
||||
text := Randn(rand.Intn(n%50+50) + 1)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("mapping: %s\ntext: %s", printableKeywords(), text)
|
||||
}
|
||||
}()
|
||||
val := rep.Replace(text)
|
||||
keys := rep.(*replacer).node.find([]rune(val))
|
||||
if len(keys) > 0 {
|
||||
t.Errorf("mapping: %s\ntext: %s\nresult: %s\nmatch: %v",
|
||||
printableKeywords(), text, val, keys)
|
||||
}
|
||||
})
|
||||
}
|
@ -15,6 +15,14 @@ func TestReplacer_Replace(t *testing.T) {
|
||||
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceOverlap(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"3d": "34",
|
||||
"bc": "23",
|
||||
}
|
||||
assert.Equal(t, "a234e", NewReplacer(mapping).Replace("abcde"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceSingleChar(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"二": "2",
|
||||
@ -42,3 +50,99 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
"cde": "234",
|
||||
}
|
||||
assert.Equal(t, "ab234fg", NewReplacer(mapping).Replace("abcdefg"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpToFailDup(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
"ccde": "2234",
|
||||
}
|
||||
assert.Equal(t, "ab2234fg", NewReplacer(mapping).Replace("abccdefg"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpToFailEnding(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
"cdef": "2345",
|
||||
}
|
||||
assert.Equal(t, "ab2345", NewReplacer(mapping).Replace("abcdef"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceEmpty(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
"cdef": "2345",
|
||||
}
|
||||
assert.Equal(t, "", NewReplacer(mapping).Replace(""))
|
||||
}
|
||||
|
||||
func TestFuzzCase1(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"yQyJykiqoh": "xw",
|
||||
"tgN70z": "Q2P",
|
||||
"tXKhEn": "w1G8",
|
||||
"5nfOW1XZO": "GN",
|
||||
"f4Ov9i9nHD": "cT",
|
||||
"1ov9Q": "Y",
|
||||
"7IrC9n": "400i",
|
||||
"JQLxonpHkOjv": "XI",
|
||||
"DyHQ3c7": "Ygxux",
|
||||
"ffyqJi": "u",
|
||||
"UHuvXrbD8pni": "dN",
|
||||
"LIDzNbUlTX": "g",
|
||||
"yN9WZh2rkc8Q": "3U",
|
||||
"Vhk11rz8CObceC": "jf",
|
||||
"R0Rt4H2qChUQf": "7U5M",
|
||||
"MGQzzPCVKjV9": "yYz",
|
||||
"B5jUUl0u1XOY": "l4PZ",
|
||||
"pdvp2qfLgG8X": "BM562",
|
||||
"ZKl9qdApXJ2": "T",
|
||||
"37jnugkSevU66": "aOHFX",
|
||||
}
|
||||
rep := NewReplacer(keywords)
|
||||
text := "yjF8fyqJiiqrczOCVyoYbLvrMpnkj"
|
||||
val := rep.Replace(text)
|
||||
keys := rep.(*replacer).node.find([]rune(val))
|
||||
if len(keys) > 0 {
|
||||
t.Errorf("result: %s, match: %v", val, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuzzCase2(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"dmv2SGZvq9Yz": "TE",
|
||||
"rCL5DRI9uFP8": "hvsc8",
|
||||
"7pSA2jaomgg": "v",
|
||||
"kWSQvjVOIAxR": "Oje",
|
||||
"hgU5bYYkD3r6": "qCXu",
|
||||
"0eh6uI": "MMlt",
|
||||
"3USZSl85EKeMzw": "Pc",
|
||||
"JONmQSuXa": "dX",
|
||||
"EO1WIF": "G",
|
||||
"uUmFJGVmacjF": "1N",
|
||||
"DHpw7": "M",
|
||||
"NYB2bm": "CPya",
|
||||
"9FiNvBAHHNku5": "7FlDE",
|
||||
"tJi3I4WxcY": "q5",
|
||||
"sNJ8Z1ToBV0O": "tl",
|
||||
"0iOg72QcPo": "RP",
|
||||
"pSEqeL": "5KZ",
|
||||
"GOyYqTgmvQ": "9",
|
||||
"Qv4qCsj": "nl52E",
|
||||
"wNQ5tOutYu5s8": "6iGa",
|
||||
}
|
||||
rep := NewReplacer(keywords)
|
||||
text := "AoRxrdKWsGhFpXwVqMLWRL74OukwjBuBh0g7pSrk"
|
||||
val := rep.Replace(text)
|
||||
keys := rep.(*replacer).node.find([]rune(val))
|
||||
if len(keys) > 0 {
|
||||
t.Errorf("result: %s, match: %v", val, keys)
|
||||
}
|
||||
}
|
||||
|
@ -39,6 +39,8 @@ func NewTrie(words []string, opts ...TrieOption) Trie {
|
||||
n.add(word)
|
||||
}
|
||||
|
||||
n.build()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
@ -48,7 +50,7 @@ func (n *trieNode) Filter(text string) (sentence string, keywords []string, foun
|
||||
return text, nil, false
|
||||
}
|
||||
|
||||
scopes := n.findKeywordScopes(chars)
|
||||
scopes := n.find(chars)
|
||||
keywords = n.collectKeywords(chars, scopes)
|
||||
|
||||
for _, match := range scopes {
|
||||
@ -65,7 +67,7 @@ func (n *trieNode) FindKeywords(text string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
scopes := n.findKeywordScopes(chars)
|
||||
scopes := n.find(chars)
|
||||
return n.collectKeywords(chars, scopes)
|
||||
}
|
||||
|
||||
@ -85,48 +87,6 @@ func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string {
|
||||
return keywords
|
||||
}
|
||||
|
||||
func (n *trieNode) findKeywordScopes(chars []rune) []scope {
|
||||
var scopes []scope
|
||||
size := len(chars)
|
||||
start := -1
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
child, ok := n.children[chars[i]]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
if child.end {
|
||||
scopes = append(scopes, scope{
|
||||
start: start,
|
||||
stop: i + 1,
|
||||
})
|
||||
}
|
||||
|
||||
for j := i + 1; j < size; j++ {
|
||||
grandchild, ok := child.children[chars[j]]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
child = grandchild
|
||||
if child.end {
|
||||
scopes = append(scopes, scope{
|
||||
start: start,
|
||||
stop: j + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
start = -1
|
||||
}
|
||||
|
||||
return scopes
|
||||
}
|
||||
|
||||
func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
|
||||
for i := start; i < stop; i++ {
|
||||
chars[i] = n.mask
|
||||
|
@ -6,6 +6,17 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTrieSimple(t *testing.T) {
|
||||
trie := NewTrie([]string{
|
||||
"bc",
|
||||
"cd",
|
||||
})
|
||||
output, keywords, found := trie.Filter("abcd")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "a***", output)
|
||||
assert.ElementsMatch(t, []string{"bc", "cd"}, keywords)
|
||||
}
|
||||
|
||||
func TestTrie(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@ -14,11 +25,11 @@ func TestTrie(t *testing.T) {
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
input: "日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
|
||||
input: "日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
|
||||
output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演",
|
||||
keywords: []string{
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"无名氏",
|
||||
"AV",
|
||||
"日本AV女优",
|
||||
"AV演员色情",
|
||||
@ -89,7 +100,7 @@ func TestTrie(t *testing.T) {
|
||||
"一不",
|
||||
"AV",
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"无名氏",
|
||||
"AV演员色情",
|
||||
"日本AV女优",
|
||||
})
|
||||
@ -145,20 +156,3 @@ func TestTrieNested(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "零########九十", output)
|
||||
}
|
||||
|
||||
func BenchmarkTrie(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
trie := NewTrie([]string{
|
||||
"A",
|
||||
"AV",
|
||||
"AV演员",
|
||||
"苍井空",
|
||||
"AV演员色情",
|
||||
"日本AV女优",
|
||||
})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
trie.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
|
||||
}
|
||||
}
|
||||
|
@ -3,10 +3,6 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
type (
|
||||
// SharedCalls is an alias of SingleFlight.
|
||||
// Deprecated: use SingleFlight.
|
||||
SharedCalls = SingleFlight
|
||||
|
||||
// SingleFlight lets the concurrent calls with the same key to share the call result.
|
||||
// For example, A called F, before it's done, B called F. Then B would not execute F,
|
||||
// and shared the result returned by F which called by A.
|
||||
@ -37,12 +33,6 @@ func NewSingleFlight() SingleFlight {
|
||||
}
|
||||
}
|
||||
|
||||
// NewSharedCalls returns a SingleFlight.
|
||||
// Deprecated: use NewSingleFlight.
|
||||
func NewSharedCalls() SingleFlight {
|
||||
return NewSingleFlight()
|
||||
}
|
||||
|
||||
func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
c, done := g.createCall(key)
|
||||
if done {
|
||||
|
Loading…
Reference in New Issue
Block a user