You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

806 lines
24 KiB

  1. /* AST Optimizer */
  2. #include "Python.h"
  3. #include "Python-ast.h"
  4. #include "node.h"
  5. #include "ast.h"
  6. /* TODO: is_const and get_const_value are copied from Python/compile.c.
  7. It should be deduped in the future. Maybe, we can include this file
  8. from compile.c?
  9. */
  10. static int
  11. is_const(expr_ty e)
  12. {
  13. switch (e->kind) {
  14. case Constant_kind:
  15. case Num_kind:
  16. case Str_kind:
  17. case Bytes_kind:
  18. case Ellipsis_kind:
  19. case NameConstant_kind:
  20. return 1;
  21. default:
  22. return 0;
  23. }
  24. }
  25. static PyObject *
  26. get_const_value(expr_ty e)
  27. {
  28. switch (e->kind) {
  29. case Constant_kind:
  30. return e->v.Constant.value;
  31. case Num_kind:
  32. return e->v.Num.n;
  33. case Str_kind:
  34. return e->v.Str.s;
  35. case Bytes_kind:
  36. return e->v.Bytes.s;
  37. case Ellipsis_kind:
  38. return Py_Ellipsis;
  39. case NameConstant_kind:
  40. return e->v.NameConstant.value;
  41. default:
  42. Py_UNREACHABLE();
  43. }
  44. }
  45. static int
  46. make_const(expr_ty node, PyObject *val, PyArena *arena)
  47. {
  48. if (val == NULL) {
  49. if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
  50. return 0;
  51. }
  52. PyErr_Clear();
  53. return 1;
  54. }
  55. if (PyArena_AddPyObject(arena, val) < 0) {
  56. Py_DECREF(val);
  57. return 0;
  58. }
  59. node->kind = Constant_kind;
  60. node->v.Constant.value = val;
  61. return 1;
  62. }
  63. #define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
  64. static PyObject*
  65. unary_not(PyObject *v)
  66. {
  67. int r = PyObject_IsTrue(v);
  68. if (r < 0)
  69. return NULL;
  70. return PyBool_FromLong(!r);
  71. }
  72. static int
  73. fold_unaryop(expr_ty node, PyArena *arena, int optimize)
  74. {
  75. expr_ty arg = node->v.UnaryOp.operand;
  76. if (!is_const(arg)) {
  77. /* Fold not into comparison */
  78. if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
  79. asdl_seq_LEN(arg->v.Compare.ops) == 1) {
  80. /* Eq and NotEq are often implemented in terms of one another, so
  81. folding not (self == other) into self != other breaks implementation
  82. of !=. Detecting such cases doesn't seem worthwhile.
  83. Python uses </> for 'is subset'/'is superset' operations on sets.
  84. They don't satisfy not folding laws. */
  85. int op = asdl_seq_GET(arg->v.Compare.ops, 0);
  86. switch (op) {
  87. case Is:
  88. op = IsNot;
  89. break;
  90. case IsNot:
  91. op = Is;
  92. break;
  93. case In:
  94. op = NotIn;
  95. break;
  96. case NotIn:
  97. op = In;
  98. break;
  99. default:
  100. op = 0;
  101. }
  102. if (op) {
  103. asdl_seq_SET(arg->v.Compare.ops, 0, op);
  104. COPY_NODE(node, arg);
  105. return 1;
  106. }
  107. }
  108. return 1;
  109. }
  110. typedef PyObject *(*unary_op)(PyObject*);
  111. static const unary_op ops[] = {
  112. [Invert] = PyNumber_Invert,
  113. [Not] = unary_not,
  114. [UAdd] = PyNumber_Positive,
  115. [USub] = PyNumber_Negative,
  116. };
  117. PyObject *newval = ops[node->v.UnaryOp.op](get_const_value(arg));
  118. return make_const(node, newval, arena);
  119. }
  120. /* Check whether a collection doesn't containing too much items (including
  121. subcollections). This protects from creating a constant that needs
  122. too much time for calculating a hash.
  123. "limit" is the maximal number of items.
  124. Returns the negative number if the total number of items exceeds the
  125. limit. Otherwise returns the limit minus the total number of items.
  126. */
  127. static Py_ssize_t
  128. check_complexity(PyObject *obj, Py_ssize_t limit)
  129. {
  130. if (PyTuple_Check(obj)) {
  131. Py_ssize_t i;
  132. limit -= PyTuple_GET_SIZE(obj);
  133. for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
  134. limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
  135. }
  136. return limit;
  137. }
  138. else if (PyFrozenSet_Check(obj)) {
  139. Py_ssize_t i = 0;
  140. PyObject *item;
  141. Py_hash_t hash;
  142. limit -= PySet_GET_SIZE(obj);
  143. while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
  144. limit = check_complexity(item, limit);
  145. }
  146. }
  147. return limit;
  148. }
  149. #define MAX_INT_SIZE 128 /* bits */
  150. #define MAX_COLLECTION_SIZE 256 /* items */
  151. #define MAX_STR_SIZE 4096 /* characters */
  152. #define MAX_TOTAL_ITEMS 1024 /* including nested collections */
  153. static PyObject *
  154. safe_multiply(PyObject *v, PyObject *w)
  155. {
  156. if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
  157. size_t vbits = _PyLong_NumBits(v);
  158. size_t wbits = _PyLong_NumBits(w);
  159. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  160. return NULL;
  161. }
  162. if (vbits + wbits > MAX_INT_SIZE) {
  163. return NULL;
  164. }
  165. }
  166. else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
  167. Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
  168. PySet_GET_SIZE(w);
  169. if (size) {
  170. long n = PyLong_AsLong(v);
  171. if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
  172. return NULL;
  173. }
  174. if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
  175. return NULL;
  176. }
  177. }
  178. }
  179. else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
  180. Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
  181. PyBytes_GET_SIZE(w);
  182. if (size) {
  183. long n = PyLong_AsLong(v);
  184. if (n < 0 || n > MAX_STR_SIZE / size) {
  185. return NULL;
  186. }
  187. }
  188. }
  189. else if (PyLong_Check(w) &&
  190. (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
  191. PyUnicode_Check(v) || PyBytes_Check(v)))
  192. {
  193. return safe_multiply(w, v);
  194. }
  195. return PyNumber_Multiply(v, w);
  196. }
  197. static PyObject *
  198. safe_power(PyObject *v, PyObject *w)
  199. {
  200. if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
  201. size_t vbits = _PyLong_NumBits(v);
  202. size_t wbits = PyLong_AsSize_t(w);
  203. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  204. return NULL;
  205. }
  206. if (vbits > MAX_INT_SIZE / wbits) {
  207. return NULL;
  208. }
  209. }
  210. return PyNumber_Power(v, w, Py_None);
  211. }
  212. static PyObject *
  213. safe_lshift(PyObject *v, PyObject *w)
  214. {
  215. if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
  216. size_t vbits = _PyLong_NumBits(v);
  217. size_t wbits = PyLong_AsSize_t(w);
  218. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  219. return NULL;
  220. }
  221. if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
  222. return NULL;
  223. }
  224. }
  225. return PyNumber_Lshift(v, w);
  226. }
  227. static PyObject *
  228. safe_mod(PyObject *v, PyObject *w)
  229. {
  230. if (PyUnicode_Check(v) || PyBytes_Check(v)) {
  231. return NULL;
  232. }
  233. return PyNumber_Remainder(v, w);
  234. }
  235. static int
  236. fold_binop(expr_ty node, PyArena *arena, int optimize)
  237. {
  238. expr_ty lhs, rhs;
  239. lhs = node->v.BinOp.left;
  240. rhs = node->v.BinOp.right;
  241. if (!is_const(lhs) || !is_const(rhs)) {
  242. return 1;
  243. }
  244. PyObject *lv = get_const_value(lhs);
  245. PyObject *rv = get_const_value(rhs);
  246. PyObject *newval;
  247. switch (node->v.BinOp.op) {
  248. case Add:
  249. newval = PyNumber_Add(lv, rv);
  250. break;
  251. case Sub:
  252. newval = PyNumber_Subtract(lv, rv);
  253. break;
  254. case Mult:
  255. newval = safe_multiply(lv, rv);
  256. break;
  257. case Div:
  258. newval = PyNumber_TrueDivide(lv, rv);
  259. break;
  260. case FloorDiv:
  261. newval = PyNumber_FloorDivide(lv, rv);
  262. break;
  263. case Mod:
  264. newval = safe_mod(lv, rv);
  265. break;
  266. case Pow:
  267. newval = safe_power(lv, rv);
  268. break;
  269. case LShift:
  270. newval = safe_lshift(lv, rv);
  271. break;
  272. case RShift:
  273. newval = PyNumber_Rshift(lv, rv);
  274. break;
  275. case BitOr:
  276. newval = PyNumber_Or(lv, rv);
  277. break;
  278. case BitXor:
  279. newval = PyNumber_Xor(lv, rv);
  280. break;
  281. case BitAnd:
  282. newval = PyNumber_And(lv, rv);
  283. break;
  284. default: // Unknown operator
  285. return 1;
  286. }
  287. return make_const(node, newval, arena);
  288. }
  289. static PyObject*
  290. make_const_tuple(asdl_seq *elts)
  291. {
  292. for (int i = 0; i < asdl_seq_LEN(elts); i++) {
  293. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  294. if (!is_const(e)) {
  295. return NULL;
  296. }
  297. }
  298. PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
  299. if (newval == NULL) {
  300. return NULL;
  301. }
  302. for (int i = 0; i < asdl_seq_LEN(elts); i++) {
  303. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  304. PyObject *v = get_const_value(e);
  305. Py_INCREF(v);
  306. PyTuple_SET_ITEM(newval, i, v);
  307. }
  308. return newval;
  309. }
  310. static int
  311. fold_tuple(expr_ty node, PyArena *arena, int optimize)
  312. {
  313. PyObject *newval;
  314. if (node->v.Tuple.ctx != Load)
  315. return 1;
  316. newval = make_const_tuple(node->v.Tuple.elts);
  317. return make_const(node, newval, arena);
  318. }
  319. static int
  320. fold_subscr(expr_ty node, PyArena *arena, int optimize)
  321. {
  322. PyObject *newval;
  323. expr_ty arg, idx;
  324. slice_ty slice;
  325. arg = node->v.Subscript.value;
  326. slice = node->v.Subscript.slice;
  327. if (node->v.Subscript.ctx != Load ||
  328. !is_const(arg) ||
  329. /* TODO: handle other types of slices */
  330. slice->kind != Index_kind ||
  331. !is_const(slice->v.Index.value))
  332. {
  333. return 1;
  334. }
  335. idx = slice->v.Index.value;
  336. newval = PyObject_GetItem(get_const_value(arg), get_const_value(idx));
  337. return make_const(node, newval, arena);
  338. }
  339. /* Change literal list or set of constants into constant
  340. tuple or frozenset respectively. Change literal list of
  341. non-constants into tuple.
  342. Used for right operand of "in" and "not in" tests and for iterable
  343. in "for" loop and comprehensions.
  344. */
  345. static int
  346. fold_iter(expr_ty arg, PyArena *arena, int optimize)
  347. {
  348. PyObject *newval;
  349. if (arg->kind == List_kind) {
  350. /* First change a list into tuple. */
  351. asdl_seq *elts = arg->v.List.elts;
  352. Py_ssize_t n = asdl_seq_LEN(elts);
  353. for (Py_ssize_t i = 0; i < n; i++) {
  354. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  355. if (e->kind == Starred_kind) {
  356. return 1;
  357. }
  358. }
  359. expr_context_ty ctx = arg->v.List.ctx;
  360. arg->kind = Tuple_kind;
  361. arg->v.Tuple.elts = elts;
  362. arg->v.Tuple.ctx = ctx;
  363. /* Try to create a constant tuple. */
  364. newval = make_const_tuple(elts);
  365. }
  366. else if (arg->kind == Set_kind) {
  367. newval = make_const_tuple(arg->v.Set.elts);
  368. if (newval) {
  369. Py_SETREF(newval, PyFrozenSet_New(newval));
  370. }
  371. }
  372. else {
  373. return 1;
  374. }
  375. return make_const(arg, newval, arena);
  376. }
  377. static int
  378. fold_compare(expr_ty node, PyArena *arena, int optimize)
  379. {
  380. asdl_int_seq *ops;
  381. asdl_seq *args;
  382. Py_ssize_t i;
  383. ops = node->v.Compare.ops;
  384. args = node->v.Compare.comparators;
  385. /* TODO: optimize cases with literal arguments. */
  386. /* Change literal list or set in 'in' or 'not in' into
  387. tuple or frozenset respectively. */
  388. i = asdl_seq_LEN(ops) - 1;
  389. int op = asdl_seq_GET(ops, i);
  390. if (op == In || op == NotIn) {
  391. if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, optimize)) {
  392. return 0;
  393. }
  394. }
  395. return 1;
  396. }
  397. static int astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_);
  398. static int astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_);
  399. static int astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_);
  400. static int astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_);
  401. static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_);
  402. static int astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_);
  403. static int astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_);
  404. static int astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_);
  405. static int astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_);
  406. static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_);
  407. #define CALL(FUNC, TYPE, ARG) \
  408. if (!FUNC((ARG), ctx_, optimize_)) \
  409. return 0;
  410. #define CALL_OPT(FUNC, TYPE, ARG) \
  411. if ((ARG) != NULL && !FUNC((ARG), ctx_, optimize_)) \
  412. return 0;
  413. #define CALL_SEQ(FUNC, TYPE, ARG) { \
  414. int i; \
  415. asdl_seq *seq = (ARG); /* avoid variable capture */ \
  416. for (i = 0; i < asdl_seq_LEN(seq); i++) { \
  417. TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
  418. if (elt != NULL && !FUNC(elt, ctx_, optimize_)) \
  419. return 0; \
  420. } \
  421. }
  422. #define CALL_INT_SEQ(FUNC, TYPE, ARG) { \
  423. int i; \
  424. asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
  425. for (i = 0; i < asdl_seq_LEN(seq); i++) { \
  426. TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
  427. if (!FUNC(elt, ctx_, optimize_)) \
  428. return 0; \
  429. } \
  430. }
  431. static int
  432. astfold_body(asdl_seq *stmts, PyArena *ctx_, int optimize_)
  433. {
  434. int docstring = _PyAST_GetDocString(stmts) != NULL;
  435. CALL_SEQ(astfold_stmt, stmt_ty, stmts);
  436. if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
  437. stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
  438. asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
  439. if (!values) {
  440. return 0;
  441. }
  442. asdl_seq_SET(values, 0, st->v.Expr.value);
  443. expr_ty expr = JoinedStr(values, st->lineno, st->col_offset, ctx_);
  444. if (!expr) {
  445. return 0;
  446. }
  447. st->v.Expr.value = expr;
  448. }
  449. return 1;
  450. }
  451. static int
  452. astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_)
  453. {
  454. switch (node_->kind) {
  455. case Module_kind:
  456. CALL(astfold_body, asdl_seq, node_->v.Module.body);
  457. break;
  458. case Interactive_kind:
  459. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
  460. break;
  461. case Expression_kind:
  462. CALL(astfold_expr, expr_ty, node_->v.Expression.body);
  463. break;
  464. case Suite_kind:
  465. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Suite.body);
  466. break;
  467. default:
  468. break;
  469. }
  470. return 1;
  471. }
  472. static int
  473. astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_)
  474. {
  475. switch (node_->kind) {
  476. case BoolOp_kind:
  477. CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
  478. break;
  479. case BinOp_kind:
  480. CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
  481. CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
  482. CALL(fold_binop, expr_ty, node_);
  483. break;
  484. case UnaryOp_kind:
  485. CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
  486. CALL(fold_unaryop, expr_ty, node_);
  487. break;
  488. case Lambda_kind:
  489. CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
  490. CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
  491. break;
  492. case IfExp_kind:
  493. CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
  494. CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
  495. CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
  496. break;
  497. case Dict_kind:
  498. CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
  499. CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
  500. break;
  501. case Set_kind:
  502. CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
  503. break;
  504. case ListComp_kind:
  505. CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
  506. CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
  507. break;
  508. case SetComp_kind:
  509. CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
  510. CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
  511. break;
  512. case DictComp_kind:
  513. CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
  514. CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
  515. CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
  516. break;
  517. case GeneratorExp_kind:
  518. CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
  519. CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
  520. break;
  521. case Await_kind:
  522. CALL(astfold_expr, expr_ty, node_->v.Await.value);
  523. break;
  524. case Yield_kind:
  525. CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
  526. break;
  527. case YieldFrom_kind:
  528. CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
  529. break;
  530. case Compare_kind:
  531. CALL(astfold_expr, expr_ty, node_->v.Compare.left);
  532. CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
  533. CALL(fold_compare, expr_ty, node_);
  534. break;
  535. case Call_kind:
  536. CALL(astfold_expr, expr_ty, node_->v.Call.func);
  537. CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
  538. CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
  539. break;
  540. case FormattedValue_kind:
  541. CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
  542. CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
  543. break;
  544. case JoinedStr_kind:
  545. CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
  546. break;
  547. case Attribute_kind:
  548. CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
  549. break;
  550. case Subscript_kind:
  551. CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
  552. CALL(astfold_slice, slice_ty, node_->v.Subscript.slice);
  553. CALL(fold_subscr, expr_ty, node_);
  554. break;
  555. case Starred_kind:
  556. CALL(astfold_expr, expr_ty, node_->v.Starred.value);
  557. break;
  558. case List_kind:
  559. CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
  560. break;
  561. case Tuple_kind:
  562. CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
  563. CALL(fold_tuple, expr_ty, node_);
  564. break;
  565. case Name_kind:
  566. if (_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
  567. return make_const(node_, PyBool_FromLong(!optimize_), ctx_);
  568. }
  569. break;
  570. default:
  571. break;
  572. }
  573. return 1;
  574. }
  575. static int
  576. astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_)
  577. {
  578. switch (node_->kind) {
  579. case Slice_kind:
  580. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
  581. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
  582. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
  583. break;
  584. case ExtSlice_kind:
  585. CALL_SEQ(astfold_slice, slice_ty, node_->v.ExtSlice.dims);
  586. break;
  587. case Index_kind:
  588. CALL(astfold_expr, expr_ty, node_->v.Index.value);
  589. break;
  590. default:
  591. break;
  592. }
  593. return 1;
  594. }
  595. static int
  596. astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_)
  597. {
  598. CALL(astfold_expr, expr_ty, node_->value);
  599. return 1;
  600. }
  601. static int
  602. astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_)
  603. {
  604. CALL(astfold_expr, expr_ty, node_->target);
  605. CALL(astfold_expr, expr_ty, node_->iter);
  606. CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
  607. CALL(fold_iter, expr_ty, node_->iter);
  608. return 1;
  609. }
  610. static int
  611. astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_)
  612. {
  613. CALL_SEQ(astfold_arg, arg_ty, node_->args);
  614. CALL_OPT(astfold_arg, arg_ty, node_->vararg);
  615. CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
  616. CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
  617. CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
  618. CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
  619. return 1;
  620. }
  621. static int
  622. astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_)
  623. {
  624. CALL_OPT(astfold_expr, expr_ty, node_->annotation);
  625. return 1;
  626. }
  627. static int
  628. astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_)
  629. {
  630. switch (node_->kind) {
  631. case FunctionDef_kind:
  632. CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
  633. CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
  634. CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
  635. CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
  636. break;
  637. case AsyncFunctionDef_kind:
  638. CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
  639. CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
  640. CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
  641. CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
  642. break;
  643. case ClassDef_kind:
  644. CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
  645. CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
  646. CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
  647. CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
  648. break;
  649. case Return_kind:
  650. CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
  651. break;
  652. case Delete_kind:
  653. CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
  654. break;
  655. case Assign_kind:
  656. CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
  657. CALL(astfold_expr, expr_ty, node_->v.Assign.value);
  658. break;
  659. case AugAssign_kind:
  660. CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
  661. CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
  662. break;
  663. case AnnAssign_kind:
  664. CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
  665. CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
  666. CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
  667. break;
  668. case For_kind:
  669. CALL(astfold_expr, expr_ty, node_->v.For.target);
  670. CALL(astfold_expr, expr_ty, node_->v.For.iter);
  671. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
  672. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
  673. CALL(fold_iter, expr_ty, node_->v.For.iter);
  674. break;
  675. case AsyncFor_kind:
  676. CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
  677. CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
  678. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
  679. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
  680. break;
  681. case While_kind:
  682. CALL(astfold_expr, expr_ty, node_->v.While.test);
  683. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
  684. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
  685. break;
  686. case If_kind:
  687. CALL(astfold_expr, expr_ty, node_->v.If.test);
  688. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
  689. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
  690. break;
  691. case With_kind:
  692. CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
  693. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
  694. break;
  695. case AsyncWith_kind:
  696. CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
  697. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
  698. break;
  699. case Raise_kind:
  700. CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
  701. CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
  702. break;
  703. case Try_kind:
  704. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
  705. CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
  706. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
  707. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
  708. break;
  709. case Assert_kind:
  710. CALL(astfold_expr, expr_ty, node_->v.Assert.test);
  711. CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
  712. break;
  713. case Expr_kind:
  714. CALL(astfold_expr, expr_ty, node_->v.Expr.value);
  715. break;
  716. default:
  717. break;
  718. }
  719. return 1;
  720. }
  721. static int
  722. astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_)
  723. {
  724. switch (node_->kind) {
  725. case ExceptHandler_kind:
  726. CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
  727. CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
  728. break;
  729. default:
  730. break;
  731. }
  732. return 1;
  733. }
  734. static int
  735. astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_)
  736. {
  737. CALL(astfold_expr, expr_ty, node_->context_expr);
  738. CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
  739. return 1;
  740. }
  741. #undef CALL
  742. #undef CALL_OPT
  743. #undef CALL_SEQ
  744. #undef CALL_INT_SEQ
  745. int
  746. _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize)
  747. {
  748. int ret = astfold_mod(mod, arena, optimize);
  749. assert(ret || PyErr_Occurred());
  750. return ret;
  751. }