1 
2 //          Copyright Tim Schendekehl 2023.
3 // Distributed under the Boost Software License, Version 1.0.
4 //    (See accompanying file LICENSE_1_0.txt or copy at
5 //          https://www.boost.org/LICENSE_1_0.txt)
6 
7 module dparsergen.core.nonterminalunion;
8 import dparsergen.core.grammarinfo;
9 import dparsergen.core.utils;
10 import std.conv;
11 import std.typetuple;
12 
/*
Linear search usable in CTFE: yields the position of the first entry of
`haystack` equal to `needle`, or -1 when `needle` does not occur.
*/
private ptrdiff_t simpleCountUntil(const SymbolID[] haystack, SymbolID needle)
{
    for (size_t index = 0; index < haystack.length; index++)
    {
        if (haystack[index] == needle)
            return cast(ptrdiff_t) index;
    }
    return -1;
}
20 
template GenericNonterminalUnion(alias CreatorInstance)
{
    /**
    Tagged union of types for nonterminals. Used internally by the parser.
    The tree creator can also choose a custom implementation.

    Params:
        singleNonterminalID = If not `SymbolID.max`, the union stores only the
            type of this nonterminal and its tag is a compile-time constant.
            If `SymbolID.max`, the union stores the type of every nonterminal
            in `CreatorInstance.startNonterminalID .. CreatorInstance.endNonterminalID`
            whose size is at most `maxSize`, and the tag is kept at runtime.
        maxSize = Size limit (in bytes) for member types; only consulted when
            `singleNonterminalID == SymbolID.max`.
    */
    struct Union(SymbolID singleNonterminalID, size_t maxSize)
    {
        alias Location = CreatorInstance.Location;

        /// Type used by the tree creator for nonterminal `nonterminalID`.
        template NonterminalType(SymbolID nonterminalID)
                if ((nonterminalID >= CreatorInstance.startNonterminalID
                    && nonterminalID < CreatorInstance.endNonterminalID)
                    || nonterminalID == SymbolID.max)
        {
            alias NonterminalType = CreatorInstance.NonterminalType!nonterminalID;
        }

        static if (singleNonterminalID == SymbolID.max)
        {
            /// Storable nonterminals, in the same order as the members of `values`.
            static immutable nonterminalIDs = () {
                SymbolID[] r;
                static foreach (i; CreatorInstance.startNonterminalID
                        .. CreatorInstance.endNonterminalID)
                {
                    if (NonterminalType!i.sizeof <= maxSize)
                        r ~= i;
                }
                return r;
            }();
            union
            {
                // One member per entry of `nonterminalIDs`, all overlapping.
                staticMap!(NonterminalType, arrayToAliasSeq!(nonterminalIDs)) values;
            }

            /// Runtime tag. The default `SymbolID.max` is treated as "no value"
            /// by `opAssign`.
            SymbolID nonterminalID = SymbolID.max;
        }
        else
        {
            // Single-nonterminal mode: one member and a compile-time tag.
            AliasSeq!(NonterminalType!(singleNonterminalID)) values;
            enum nonterminalID = singleNonterminalID;
            static immutable nonterminalIDs = [singleNonterminalID];
        }

        /**
        Returns the stored value, which must belong to `nonterminalID2`
        (checked by the `in` contract). Fails to compile if `nonterminalID2`
        cannot be stored in this union.
        */
        inout(NonterminalType!nonterminalID2) get(SymbolID nonterminalID2)() inout
        in
        {
            assert(nonterminalID2 == nonterminalID, text(nonterminalID2, "  ", nonterminalID));
        }
        do
        {
            enum k = simpleCountUntil(nonterminalIDs, nonterminalID2);
            static assert(k != -1, text(nonterminalID2, " ", nonterminalIDs,
                    " ", singleNonterminalID, " ", maxSize));
            return values[k];
        }

        /**
        Returns the stored value for whichever of `nonterminalID2s` matches the
        current tag; the return type is inferred from all candidates. Asserts
        if none of them matches.
        */
        auto get(nonterminalID2s...)() inout if (nonterminalID2s.length >= 2)
        {
            // Iterating a compile-time sequence: this loop is unrolled.
            foreach (nonterminalID2; nonterminalID2s)
            {
                if (nonterminalID2 == nonterminalID)
                    return get!nonterminalID2();
            }
            assert(false);
        }

        /**
        Returns the stored value as `T`. Only members whose type converts to
        `T` (both compared as `const`) are considered; asserts if the current
        tag does not belong to one of them.
        */
        inout(T) getT(T)() inout
        {
            static foreach (k, nonterminalID2; nonterminalIDs)
            {
                static if (is(const(typeof(values[k])) : const(T)))
                {
                    if (nonterminalID2 == nonterminalID)
                        return get!nonterminalID2();
                }
            }
            assert(false);
        }

        /**
        Stores `data` as the value of nonterminal `nonterminalID3` and, for
        multi-nonterminal unions, updates the tag. Asserts if `nonterminalID3`
        is not a storable nonterminal with a matching member type.

        NOTE(review): the filter checks that the member type converts to `T`,
        while the assignment needs `T` to convert to the member type; both hold
        when the types are equal. Confirm this is the intended contract.
        */
        void setT(T)(T data, SymbolID nonterminalID3)
        {
            static foreach (k, nonterminalID2; nonterminalIDs)
            {
                static if (is(const(typeof(values[k])) : const(T)))
                {
                    if (nonterminalID2 == nonterminalID3)
                    {
                        // In single-nonterminal mode the tag is an `enum` and
                        // already equals nonterminalID3 here, so it is only
                        // written when it is a runtime field.
                        static if (singleNonterminalID == SymbolID.max)
                            nonterminalID = nonterminalID3;
                        values[k] = data;
                        return;
                    }
                }
            }
            assert(false);
        }

        /// Returns whether the current tag belongs to a member whose type
        /// converts to `T` (both compared as `const`).
        bool isType(T)() const
        {
            static foreach (k, nonterminalID2; nonterminalIDs)
            {
                static if (is(const(typeof(values[k])) : const(T)))
                {
                    if (nonterminalID2 == nonterminalID)
                        return true;
                }
            }
            return false;
        }

        /**
        Creates a union holding `tree` for nonterminal `nonterminalID`.
        `T` must be exactly one of the member types. In single-nonterminal
        mode, `nonterminalID` must equal `singleNonterminalID` (asserted).
        */
        static Union create(T)(SymbolID nonterminalID, T tree)
        {
            Union r;
            enum i = staticIndexOf!(T, typeof(values));
            static assert(i != -1, "Union.create: " ~ T.stringof
                    ~ " is not a member type of this union");
            r.values[i] = tree;

            static if (singleNonterminalID == SymbolID.max)
                r.nonterminalID = nonterminalID;
            else
                assert(r.nonterminalID == nonterminalID, text(nonterminalID, nonterminalIDs));
            return r;
        }

        /**
        Creates a union holding the default value (`.init`) of the member type
        for `nonterminalID`. Asserts if `nonterminalID` cannot be stored.
        */
        static Union create()(SymbolID nonterminalID)
        {
            Union r;
            static if (singleNonterminalID == SymbolID.max)
            {
                r.nonterminalID = nonterminalID;
                bool found;
                static foreach (i, nonterminalID2; nonterminalIDs)
                {
                    if (nonterminalID2 == nonterminalID)
                    {
                        r.values[i] = typeof(r.values[i]).init;
                        found = true;
                    }
                }
                assert(found, text(nonterminalID, " ", nonterminalIDs));
            }
            else
            {
                assert(r.nonterminalID == nonterminalID, text(nonterminalID, nonterminalIDs));
                r.values[0] = typeof(r.values[0]).init;
            }
            return r;
        }

        /**
        Assigns from a union with different template parameters by copying the
        value of `rhs.nonterminalID` into the matching member. If `rhs` is a
        multi-nonterminal union with tag `SymbolID.max`, this union's tag is
        reset to `SymbolID.max`; a single-nonterminal target asserts instead.
        Asserts if this union cannot store `rhs.nonterminalID`.
        */
        void opAssign(SymbolID singleNonterminalID2, size_t maxSize2)(
                Union!(singleNonterminalID2, maxSize2) rhs)
                if (singleNonterminalID2 != singleNonterminalID || maxSize2 != maxSize)
        {
            static if (singleNonterminalID2 == SymbolID.max)
            {
                if (rhs.nonterminalID == SymbolID.max)
                {
                    static if (singleNonterminalID == SymbolID.max)
                    {
                        this.nonterminalID = SymbolID.max;
                        return;
                    }
                    else
                    {
                        assert(false);
                    }
                }
            }
            static foreach (i; 0 .. rhs.nonterminalIDs.length)
            {
                // Extra braces give each unrolled iteration its own scope, so
                // `n` and `j` can be declared again in the next iteration.
                {
                    enum n = rhs.nonterminalIDs[i];
                    enum j = simpleCountUntil(nonterminalIDs, n);
                    static if (j >= 0)
                    {
                        if (n == rhs.nonterminalID)
                        {
                            values[j] = rhs.get!n;

                            // Only multi-nonterminal unions have a runtime tag.
                            // Testing `nonterminalIDs.length != 1` here would
                            // wrongly skip a multi-nonterminal union that can
                            // store just one type.
                            static if (singleNonterminalID == SymbolID.max)
                                nonterminalID = n;
                            return;
                        }
                    }
                }
            }
            assert(0);
        }
    }

    /// ditto
    template Union(alias nonterminalIDs)
    {
        // Convenience form taking a non-empty list of nonterminal IDs; the
        // struct's template arguments are derived by `Params`.
        static assert(nonterminalIDs.length > 0);
        alias Union = Union!(Params!(nonterminalIDs));
    }

    // Template arguments for `Union` derived from a list of nonterminal IDs:
    // the single ID if the list has exactly one entry (else `SymbolID.max`),
    // and the largest `sizeof` among the listed nonterminal types.
    private template Params(alias nonterminalIDs)
    {
        alias Params = AliasSeq!(
                (nonterminalIDs.length == 1) ? nonterminalIDs[0] : SymbolID.max, () {
            size_t max = 0;
            static foreach (i; 0 .. nonterminalIDs.length)
            {
                // Braces scope `n` per unrolled iteration.
                {
                    enum n = nonterminalIDs[i];
                    if (CreatorInstance.NonterminalType!n.sizeof > max)
                        max = CreatorInstance.NonterminalType!n.sizeof;
                }
            }
            return max;
        }());
    }
}