-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph.go
More file actions
179 lines (162 loc) · 4.12 KB
/
graph.go
File metadata and controls
179 lines (162 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package resrap
import (
"sort"
"strings"
"github.com/golang-collections/collections/stack"
)
// nextoption is a single outgoing edge of a syntaxNode: where it leads and
// the raw, un-normalized weight used when choosing among sibling edges.
type nextoption struct {
	node        *syntaxNode // destination of this edge
	probability float32     // relative weight; Normalize folds it into the owner's cf table
}
// nodeType tags how GraphWalk treats a syntaxNode when it reaches it.
type nodeType int8

// Node kinds. Values come from iota in declaration order, so new kinds should
// be appended at the end to keep the existing numeric values stable.
const (
	start   nodeType = iota // NOTE(review): not referenced by the walk logic in this file; presumably marks a graph entry
	header                  // entry node of a rule; the node kind requested when following a pointer
	jump                    // NOTE(review): not handled specially by GraphWalk; purpose not visible here
	end                     // end of a rule; resumes at the return address on the jump stack, if any
	ch                      // literal text (looked up in charmap), unescaped and emitted as one token
	rx                      // regex pattern (looked up in charmap), expanded by the regex handler
	pointer                 // reference to another rule; the target id lives in syntaxNode.pointer
	idk                     // presumably a "don't care" kind, passed to GetNode when the node should already exist
)
// syntaxNode is one vertex of the grammar graph.
//
// Note: reordering these fields breaks any positional composite literal of
// this type, so keep the order stable.
type syntaxNode struct {
	next    []nextoption // outgoing edges in insertion order
	cf      []float32    // cumulative weights parallel to next; rebuilt by Normalize
	id      uint32       // identifier; key of this node in syntaxGraph.nodeRef
	typ     nodeType     // what the walker does on reaching this node
	pointer uint32       // only for typ == pointer: id of the header node to jump to
}
// AddEdgeNext appends an edge from s to node with the raw weight probability.
// It also stores node in g.nodeRef under node.id, overwriting any earlier
// entry for that id.
//
// s.cf is not touched here, so Normalize has to run again for the new edge
// to be reflected in the cumulative table.
func (s *syntaxNode) AddEdgeNext(g *syntaxGraph, node *syntaxNode, probability float32) {
	s.next = append(s.next, nextoption{
		node:        node,
		probability: probability,
	})
	g.nodeRef[node.id] = node
}
// syntaxGraph owns every node of a grammar graph plus the lookup tables the
// walker consults.
type syntaxGraph struct {
	nodeRef map[uint32]*syntaxNode // node id -> node
	namemap map[string]uint32      // rule name -> id of the node a walk starts from
	charmap map[uint32]string      // node id -> literal text (ch) or regex pattern (rx)

	regexhandler regexer // expands rx patterns into concrete strings
	prng         prng    // NOTE(review): GraphWalk takes its own *prng argument and does not read this field
}
// GetNode returns the node registered under id. If there is none (or the
// stored entry is nil), it creates a bare node of kind typ, registers it, and
// returns it. typ is ignored when the node already exists.
func (s *syntaxGraph) GetNode(id uint32, typ nodeType) *syntaxNode {
	if existing := s.nodeRef[id]; existing != nil {
		return existing
	}
	created := &syntaxNode{id: id, typ: typ}
	s.nodeRef[id] = created
	return created
}
// newSyntaxGraph returns an empty graph with all three lookup maps (nodeRef,
// namemap, charmap) allocated, so callers can insert into any of them
// directly. Before this change only nodeRef was allocated, and a write to
// namemap or charmap would panic on a nil map. Reads behave the same on an
// empty map as on a nil one.
//
// regexhandler and prng are left at their zero values.
// NOTE(review): confirm callers set them before walking.
func newSyntaxGraph() syntaxGraph {
	return syntaxGraph{
		nodeRef: make(map[uint32]*syntaxNode),
		namemap: make(map[string]uint32),
		charmap: make(map[uint32]string),
	}
}
// Normalize rebuilds every node's cumulative weight table (node.cf) from the
// raw edge weights in node.next. GraphWalk then picks a successor by drawing a
// value (presumably in [0, 1)) and binary-searching cf for the first entry
// that is >= the draw.
//
// For weights w0..wk-1 with positive total W, cf[i] = (w0+...+wi)/W, and the
// final entry is exactly 1. Nodes without successors get a nil table.
func (s *syntaxGraph) Normalize() {
	for _, node := range s.nodeRef {
		if len(node.next) == 0 {
			node.cf = nil
			continue
		}
		cf := make([]float32, len(node.next))
		for i, opt := range node.next {
			cf[i] = opt.probability
		}
		cumulateWeights(cf)
		node.cf = cf
	}
}

// cumulateWeights rewrites weights in place into a normalized running sum
// whose last entry is exactly 1.
//
// If the weights do not add up to a positive number (for example, all zero),
// each edge gets an equal share instead. Dividing by such a total would fill
// the table with NaN/Inf, and every lookup would fall off the end.
//
// The last entry is forced to 1 for two reasons. Float32 rounding can leave the
// running sum slightly below 1. Also, a float64 draw just under 1 can round up
// to 1.0 when converted to float32. Either case would otherwise let a draw
// exceed every entry.
func cumulateWeights(weights []float32) {
	if len(weights) == 0 {
		return
	}
	var total float32
	for _, w := range weights {
		total += w
	}
	uniform := !(total > 0) // also true when total is NaN
	var running float32
	for i, w := range weights {
		if uniform {
			running += 1 / float32(len(weights))
		} else {
			running += w / total
		}
		weights[i] = running
	}
	weights[len(weights)-1] = 1
}
// GraphWalk generates a string by randomly walking the graph. It starts at the
// node registered in namemap for the rule name start. The walk stops in any of
// three cases:
//   - tokens literal (ch) tokens have been emitted;
//   - it reaches a node with no successors;
//   - the jump stack yields something other than a node id.
//
// Per node kind:
//   - ch: the node's charmap text is unescaped and appended; counts as one token.
//   - rx: the node's charmap pattern is expanded by s.regexhandler using prng
//     and appended; it does not count toward tokens.
//   - pointer: the id of the node's first successor is pushed as a return
//     address, and the walk jumps to the header node whose id is node.pointer.
//   - end: if a return address is pending, the walk resumes there; otherwise
//     it falls through to normal successor selection.
//
// A successor is chosen by drawing a value from prng and binary-searching the
// node's cumulative table cf (built by Normalize).
//
// It returns "" when start is not a known rule name or its node is missing.
func (s *syntaxGraph) GraphWalk(prng *prng, start string, tokens int) string {
	// Use comma-ok: a plain index would map an unknown name to id 0 and
	// silently start from whatever node happens to own that id.
	startID, known := s.namemap[start]
	if !known {
		return ""
	}
	current := s.nodeRef[startID]
	if current == nil {
		return ""
	}

	var result strings.Builder
	jumpStack := stack.New()
	printedTokens := 0

	for current != nil && printedTokens < tokens {
		switch current.typ {
		case ch:
			result.WriteString(unescapeString(s.charmap[current.id]))
			printedTokens++
		case rx:
			result.WriteString(s.regexhandler.GenerateString(s.charmap[current.id], prng))
		case pointer:
			// Record where to resume once the referenced rule reaches its end.
			// NOTE(review): assumes pointer nodes always have at least one successor.
			jumpStack.Push(current.next[0].node.id)
			current = s.GetNode(current.pointer, header)
			continue
		case end:
			if jumpStack.Len() != 0 {
				resumeID, isID := jumpStack.Pop().(uint32)
				if !isID {
					// Only uint32 ids are pushed above; anything else is
					// unexpected, so stop with what has been generated so far.
					return result.String()
				}
				current = s.GetNode(resumeID, idk)
				continue
			}
		}

		// Pick the successor. A dead end finishes the walk.
		if len(current.next) == 0 {
			break
		}
		cf := current.cf
		value := float32(prng.Random())
		idx := sort.Search(len(cf), func(i int) bool { return cf[i] >= value })
		if idx >= len(current.next) {
			// The draw landed above every cf entry. Float rounding can cause
			// this, as can a cf table out of sync with next. Fall back to the
			// last edge instead of indexing out of range.
			idx = len(current.next) - 1
		}
		current = current.next[idx].node
	}
	return result.String()
}
// unescapeString decodes the backslash escapes \n, \t, \r, \\, \' and \" in s.
// Any other backslash sequence is left untouched (both bytes kept), and a
// trailing lone backslash is copied as-is. Decoding is byte-wise. Every
// recognized escape is ASCII, so multi-byte UTF-8 sequences pass through
// unchanged.
func unescapeString(s string) string {
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c != '\\' || i+1 >= len(s) {
			out = append(out, c)
			continue
		}
		var decoded byte
		switch esc := s[i+1]; esc {
		case 'n':
			decoded = '\n'
		case 't':
			decoded = '\t'
		case 'r':
			decoded = '\r'
		case '\\', '\'', '"':
			decoded = esc
		default:
			// Unknown escape: keep the backslash; the following byte is
			// copied verbatim on the next iteration.
			out = append(out, c)
			continue
		}
		out = append(out, decoded)
		i++ // the escape letter has been consumed
	}
	return string(out)
}