genconf: properly read the value of the "if" node
[libcmdline.git] / src / genconf / expression.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <ctype.h>
5 #include <sys/queue.h>
6
7 #include "strictmalloc.h"
8 #include "expression.h"
9 #include "conf_parser.h"
10 #include "confnode.h"
11 #include "conf_htable.h"
12
13 /* XXX prefix all with "expression_" */
14
15 struct expr_node *node_create(void)
16 {
17         struct expr_node *n;
18         n = strictmalloc(sizeof(struct expr_node));
19         return n;
20 }
21
22 /* free an expression */
23 void expression_free(struct expr_node *exp)
24 {
25         if (exp == NULL)
26                 return;
27
28         if (exp->left) {
29                 expression_free(exp->left);
30                 exp->left = NULL;
31         }
32
33         if (exp->right) {
34                 expression_free(exp->right);
35                 exp->right = NULL;
36         }
37
38         free(exp);
39 }
40
41 const char *op_print(const struct expr_op *op)
42 {
43         switch (op->optype) {
44         case OP_OR:
45                 return "||";
46         case OP_AND:
47                 return "&&";
48         case OP_EQUAL:
49                 return "=";
50         case OP_OBRACKET:
51                 return "(";
52         case OP_NOT:
53                 return "!";
54         case OP_CBRACKET:
55                 return ")";
56         case OP_VAR:
57                 return op->name;
58         default:
59                 return NULL;
60         }
61 }
62
63 void __dump(struct expr_node *n)
64 {
65         if (n == NULL)
66                 return;
67
68         if (n->left) {
69                 printf("\"%p <%s>\" -> \"%p <%s>\"\n",
70                        n, op_print(&n->op),
71                        n->left, op_print(&n->left->op));
72                 __dump(n->left);
73         }
74
75         if (n->right) {
76                 printf("\"%p <%s>\" -> \"%p <%s>\"\n",
77                        n, op_print(&n->op),
78                        n->right, op_print(&n->right->op));
79                 __dump(n->right);
80         }
81 }
82
/*
 * Write the tree rooted at 'n' to stdout as a "digraph" description:
 * a fixed header, one line per edge (see __dump()), and a closing brace.
 */
void dump(struct expr_node *n)
{
        static const char header[] =
                "digraph unix {\n"
                "size=\"6,6\";\n"
                "node [color=lightblue2, style=filled];\n";

        fputs(header, stdout);
        __dump(n);
        fputs("}\n", stdout);
}
91
92 int isbiop(enum op_type optype)
93 {
94         if (optype == OP_AND ||
95             optype == OP_OR ||
96             optype == OP_EQUAL)
97                 return 1;
98         return 0;
99 }
100
101 /*
102  * Dump the expression 'node' as a string into the buffer 'buf' of
103  * length 'len'. The string is nul-terminated. Return the number of
104  * written bytes on success (not including \0), else return -1.
105  */
106 int expression_to_str(struct expr_node *node, char *buf, int len)
107 {
108         int n, orig_len;
109
110         orig_len = len;
111
112         if (isbiop(node->op.optype) || node->op.optype == OP_OBRACKET) {
113                 n = snprintf(buf, len, "(");
114                 if (n == -1 || n >= len || len <= 0)
115                         return -1;
116                 buf += n;
117                 len -= n;
118         }
119
120         if (isbiop(node->op.optype)) {
121                 n = expression_to_str(node->left, buf, len);
122                 if (n == -1 || n >= len || len <= 0)
123                         return -1;
124                 buf += n;
125                 len -= n;
126         }
127
128         if (node->op.optype != OP_OBRACKET) {
129                 if (isbiop(node->op.optype))
130                         n = snprintf(buf, len, " %s ", op_print(&node->op));
131                 else
132                         n = snprintf(buf, len, "%s", op_print(&node->op));
133                 if (n == -1 || n >= len || len <= 0)
134                         return -1;
135                 buf += n;
136                 len -= n;
137         }
138
139         if (node->op.optype != OP_VAR) {
140                 n = expression_to_str(node->right, buf, len);
141                 if (n == -1 || n >= len || len <= 0)
142                         return -1;
143                 buf += n;
144                 len -= n;
145         }
146
147         if (isbiop(node->op.optype) || node->op.optype == OP_OBRACKET) {
148                 n = snprintf(buf, len, ")");
149                 if (n == -1 || n >= len || len <= 0)
150                         return -1;
151                 buf += n;
152                 len -= n;
153         }
154
155         return orig_len - len;
156 }
157
158 /*
159  * Evaluate an expression. Return -1 on error, else the value of the
160  * expression (greater or equal than 0).
161  */
162 int expression_eval(struct expr_node *node)
163 {
164         struct confnode *conf;
165
166         switch (node->op.optype) {
167         case OP_OR:
168                 return expression_eval(node->left) || expression_eval(node->right);
169         case OP_AND:
170                 return expression_eval(node->left) && expression_eval(node->right);
171         case OP_EQUAL:
172                 return expression_eval(node->left) == expression_eval(node->right);
173         case OP_OBRACKET:
174                 return expression_eval(node->right);
175         case OP_NOT:
176                 return !expression_eval(node->right);
177         case OP_VAR:
178                  conf = conf_htable_lookup(node->op.name);
179                  if (conf == NULL)
180                          return 0;
181                  return confnode_get_boolvalue(conf);
182         default:
183                 printf("%s(): bad operator\n", __FUNCTION__);
184                 return 0;
185         }
186 }
187
/*
 * Measure a bracketed group at the start of 'buf'.
 *
 * 'buf' must begin with '('. Return the number of characters from that
 * opening bracket up to and including the closing bracket that matches
 * it (nested brackets are balanced). Return -1 if 'buf' does not start
 * with '(' or if the string ends before the group is closed.
 */
int get_brac_len(const char *buf)
{
        const char *p;
        int depth;

        if (buf[0] != '(')
                return -1;

        /* one bracket is open; scan until it is closed or the string ends */
        depth = 1;
        for (p = buf + 1; depth != 0 && *p != '\0'; p++) {
                switch (*p) {
                case '(':
                        depth++;
                        break;
                case ')':
                        depth--;
                        break;
                default:
                        break;
                }
        }

        return (depth == 0) ? (int)(p - buf) : -1;
}
208
209 /*
210  * XXX
211  */
212  static struct expr_op *get_operand(const char *buf, unsigned *eatlen)
213 {
214         struct expr_op *op;
215         const char *s;
216         unsigned len;
217
218         //printf("%s\n", buf);
219         op = strictmalloc(sizeof(struct expr_op));
220
221         s = buf;
222
223         switch (s[0]) {
224         case '#':
225                 op->optype = OP_EOF;
226                 *eatlen = s - buf + 1;
227                 return op;
228         case '\0':
229                 op->optype = OP_EOF;
230                 *eatlen = s - buf + 1;
231                 return op;
232         case '(':
233                 op->optype = OP_OBRACKET;
234                 *eatlen = s - buf + 1;
235                 return op;
236         case ')':
237                 op->optype = OP_CBRACKET;
238                 *eatlen = s - buf + 1;
239                 return op;
240         case '!':
241                 op->optype = OP_NOT;
242                 *eatlen = s - buf + 1;
243                 return op;
244         case '&':
245                 op->optype = OP_AND;
246                 if (s[1] != '&') {
247                         free(op);
248                         return NULL;
249                 }
250                 *eatlen = s - buf + 2;
251                 return op;
252         case '|':
253                 op->optype = OP_OR;
254                 if (s[1] != '|') {
255                         free(op);
256                         return NULL;
257                 }
258                 *eatlen = s - buf + 2;
259                 return op;
260         case '=':
261                 op->optype = OP_EQUAL;
262                 *eatlen = s - buf + 1;
263                 return op;
264         default:
265                 break;
266         }
267
268         /* It's a variable, get name */
269         while (*s != '\0' &&
270                *s != '(' &&
271                *s != ')' &&
272                *s != '&' &&
273                *s != '|' &&
274                *s != '!' &&
275                *s != '=' &&
276                *s != '#' &&
277                !isspace(*s))
278                 s++;
279
280         op->optype = OP_VAR;
281         *eatlen = s - buf;
282
283         /* alloc string for variable name */
284         len = s - buf;
285         op->name = strictmalloc(len+1);
286         memcpy(op->name, buf, len);
287         op->name[len+1] = '\0';
288         return op;
289 }
290
291 struct expr_node *parse_expression(char *buf)
292 {
293         struct expr_node *top = NULL, *n, *tmp;
294         char *s = buf;
295         struct expr_op *op;
296         int len;
297         unsigned eatlen;
298
299         //printf("parse <%s>\n", buf);
300
301         if (*s == '\0')
302                 return NULL;
303
304         while (1) {
305
306                 /* skip spaces */
307                 while (isspace(*s))
308                         s++;
309
310                 op = get_operand(s, &eatlen);
311                 if (op == NULL) {
312                         printf("Parse error\n");
313                         return NULL; /* XXX */
314                 }
315                 if (op->optype == OP_EOF)
316                         break;
317
318                 //printf("%s\n", op_print(op));
319
320                 n = node_create();
321                 n->op = *op;
322                 free(op);
323
324                 switch (n->op.optype) {
325                 case OP_OR:
326                 case OP_AND:
327                 case OP_EQUAL:
328                         if (top == NULL ||
329                             (!isbiop(top->op.optype) ||
330                              top->op.optype > n->op.optype)) {
331                                 n->left = top;
332                                 top = n;
333                         }
334                         else {
335                                 tmp = top->right;
336                                 top->right = n;
337                                 n->left = tmp;
338                         }
339                         break;
340                 case OP_OBRACKET:
341                         len = get_brac_len(s);
342                         if (len < 0) {
343                                 printf("Parse error, cannot find closing bracket\n");
344                                 return NULL; /* XXX */
345                         }
346                         s[len-1] = '\0';
347                         n->right = parse_expression(s+1);
348                         if (n->right == NULL)
349                                 return NULL;
350                         s[len-1] = ')';
351                         s += len - 1;
352
353                         if (top == NULL)
354                                 top = n;
355                         else {
356                                 tmp = top;
357                                 while (tmp->right)
358                                         tmp = tmp->right;
359                                 tmp->right = n;
360                         }
361
362                         break;
363                 case OP_NOT:
364                 case OP_VAR:
365                         if (top == NULL)
366                                 top = n;
367                         else {
368                                 tmp = top;
369                                 while (tmp->right)
370                                         tmp = tmp->right;
371                                 tmp->right = n;
372                         }
373                         break;
374                 default:
375                         printf("Parse error\n");
376                         return NULL; /* XXX */
377                 }
378
379                 s += eatlen;
380         }
381
382         return top;
383 }
384
#if 0
/*
 * Disabled manual test: parse a sample expression, print the resulting
 * tree with dump(), then release it.
 */
int main(void)
{
        char s[] = "!(A&&! (B || CONFIG_C ) )";
        struct expr_node *n;

        n = parse_expression(s);

        dump(n);
        expression_free(n);
        return 0;
}
#endif