never executed always true always false
    1 module GHC.Core.Opt.Exitify ( exitifyProgram ) where
    2 
    3 {-
    4 Note [Exitification]
    5 ~~~~~~~~~~~~~~~~~~~~
    6 
    7 This module implements Exitification. The goal is to pull as much code out of
    8 recursive functions as possible, as the simplifier is better at inlining into
    9 call-sites that are not in recursive functions.
   10 
   11 Example:
   12 
   13   let t = foo bar
   14   joinrec go 0     x y = t (x*x)
   15           go (n-1) x y = jump go (n-1) (x+y)
   16   in …
   17 
   18 We’d like to inline `t`, but that does not happen: Because t is a thunk and is
   19 used in a recursive function, doing so might lose sharing in general. In
   20 this case, however, `t` is on the _exit path_ of `go`, so called at most once.
   21 How do we make this clearly visible to the simplifier?
   22 
   23 A code path (i.e., an expression in a tail-recursive position) in a recursive
   24 function is an exit path if it does not contain a recursive call. We can bind
   25 this expression outside the recursive function, as a join-point.
   26 
   27 Example result:
   28 
   29   let t = foo bar
   30   join exit x = t (x*x)
   31   joinrec go 0     x y = jump exit x
   32           go (n-1) x y = jump go (n-1) (x+y)
   33   in …
   34 
   35 Now `t` is no longer in a recursive function, and good things happen!
   36 -}
   37 
   38 import GHC.Prelude
   39 import GHC.Types.Var
   40 import GHC.Types.Id
   41 import GHC.Types.Id.Info
   42 import GHC.Core
   43 import GHC.Core.Utils
   44 import GHC.Utils.Monad.State.Strict
   45 import GHC.Builtin.Uniques
   46 import GHC.Types.Var.Set
   47 import GHC.Types.Var.Env
   48 import GHC.Core.FVs
   49 import GHC.Data.FastString
   50 import GHC.Core.Type
   51 import GHC.Utils.Misc( mapSnd )
   52 
   53 import Data.Bifunctor
   54 import Control.Monad
   55 
   56 -- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
   57 -- The really interesting function is exitifyRec
   58 exitifyProgram :: CoreProgram -> CoreProgram
   59 exitifyProgram binds = map goTopLvl binds
   60   where
   61     goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
   62     goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
   63       -- Top-level bindings are never join points
   64 
   65     in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
   66 
   67     go :: InScopeSet -> CoreExpr -> CoreExpr
   68     go _    e@(Var{})       = e
   69     go _    e@(Lit {})      = e
   70     go _    e@(Type {})     = e
   71     go _    e@(Coercion {}) = e
   72     go in_scope (Cast e' c) = Cast (go in_scope e') c
   73     go in_scope (Tick t e') = Tick t (go in_scope e')
   74     go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)
   75 
   76     go in_scope (Lam v e')
   77       = Lam v (go in_scope' e')
   78       where in_scope' = in_scope `extendInScopeSet` v
   79 
   80     go in_scope (Case scrut bndr ty alts)
   81       = Case (go in_scope scrut) bndr ty (map go_alt alts)
   82       where
   83         in_scope1 = in_scope `extendInScopeSet` bndr
   84         go_alt (Alt dc pats rhs) = Alt dc pats (go in_scope' rhs)
   85            where in_scope' = in_scope1 `extendInScopeSetList` pats
   86 
   87     go in_scope (Let (NonRec bndr rhs) body)
   88       = Let (NonRec bndr (go in_scope rhs)) (go in_scope' body)
   89       where
   90         in_scope' = in_scope `extendInScopeSet` bndr
   91 
   92     go in_scope (Let (Rec pairs) body)
   93       | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
   94       | otherwise   = Let (Rec pairs') body'
   95       where
   96         is_join_rec = any (isJoinId . fst) pairs
   97         in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
   98         pairs'      = mapSnd (go in_scope') pairs
   99         body'       = go in_scope' body
  100 
  101 
-- | State monad used inside 'exitifyRec'. The state is the list of exit join
-- points created so far, each paired with its right-hand side; 'addExit'
-- conses every new one onto the front of that list.
type ExitifyM =  State [(JoinId, CoreExpr)]
  104 
  105 -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
  106 --   join-points outside the joinrec.
  107 exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
  108 exitifyRec in_scope pairs
  109   = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
  110   where
  111     -- We need the set of free variables of many subexpressions here, so
  112     -- annotate the AST with them
  113     -- see Note [Calculating free variables]
  114     ann_pairs = map (second freeVars) pairs
  115 
  116     -- Which are the recursive calls?
  117     recursive_calls = mkVarSet $ map fst pairs
  118 
  119     (pairs',exits) = (`runState` []) $
  120         forM ann_pairs $ \(x,rhs) -> do
  121             -- go past the lambdas of the join point
  122             let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
  123             body' <- go args body
  124             let rhs' = mkLams args body'
  125             return (x, rhs')
  126 
  127     ---------------------
  128     -- 'go' is the main working function.
  129     -- It goes through the RHS (tail-call positions only),
  130     -- checks if there are no more recursive calls, if so, abstracts over
  131     -- variables bound on the way and lifts it out as a join point.
  132     --
  133     -- ExitifyM is a state monad to keep track of floated binds
  134     go :: [Var]           -- ^ Variables that are in-scope here, but
  135                           -- not in scope at the joinrec; that is,
  136                           -- we must potentially abstract over them.
  137                           -- Invariant: they are kept in dependency order
  138        -> CoreExprWithFVs -- ^ Current expression in tail position
  139        -> ExitifyM CoreExpr
  140 
  141     -- We first look at the expression (no matter what it shape is)
  142     -- and determine if we can turn it into a exit join point
  143     go captured ann_e
  144         | -- An exit expression has no recursive calls
  145           let fvs = dVarSetToVarSet (freeVarsOf ann_e)
  146         , disjointVarSet fvs recursive_calls
  147         = go_exit captured (deAnnotate ann_e) fvs
  148 
  149     -- We could not turn it into a exit join point. So now recurse
  150     -- into all expression where eligible exit join points might sit,
  151     -- i.e. into all tail-call positions:
  152 
  153     -- Case right hand sides are in tail-call position
  154     go captured (_, AnnCase scrut bndr ty alts) = do
  155         alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do
  156             rhs' <- go (captured ++ [bndr] ++ pats) rhs
  157             return (Alt dc pats rhs')
  158         return $ Case (deAnnotate scrut) bndr ty alts'
  159 
  160     go captured (_, AnnLet ann_bind body)
  161         -- join point, RHS and body are in tail-call position
  162         | AnnNonRec j rhs <- ann_bind
  163         , Just join_arity <- isJoinId_maybe j
  164         = do let (params, join_body) = collectNAnnBndrs join_arity rhs
  165              join_body' <- go (captured ++ params) join_body
  166              let rhs' = mkLams params join_body'
  167              body' <- go (captured ++ [j]) body
  168              return $ Let (NonRec j rhs') body'
  169 
  170         -- rec join point, RHSs and body are in tail-call position
  171         | AnnRec pairs <- ann_bind
  172         , isJoinId (fst (head pairs))
  173         = do let js = map fst pairs
  174              pairs' <- forM pairs $ \(j,rhs) -> do
  175                  let join_arity = idJoinArity j
  176                      (params, join_body) = collectNAnnBndrs join_arity rhs
  177                  join_body' <- go (captured ++ js ++ params) join_body
  178                  let rhs' = mkLams params join_body'
  179                  return (j, rhs')
  180              body' <- go (captured ++ js) body
  181              return $ Let (Rec pairs') body'
  182 
  183         -- normal Let, only the body is in tail-call position
  184         | otherwise
  185         = do body' <- go (captured ++ bindersOf bind ) body
  186              return $ Let bind body'
  187       where bind = deAnnBind ann_bind
  188 
  189     -- Cannot be turned into an exit join point, but also has no
  190     -- tail-call subexpression. Nothing to do here.
  191     go _ ann_e = return (deAnnotate ann_e)
  192 
  193     ---------------------
  194     go_exit :: [Var]      -- Variables captured locally
  195             -> CoreExpr   -- An exit expression
  196             -> VarSet     -- Free vars of the expression
  197             -> ExitifyM CoreExpr
  198     -- go_exit deals with a tail expression that is floatable
  199     -- out as an exit point; that is, it mentions no recursive calls
  200     go_exit captured e fvs
  201       -- Do not touch an expression that is already a join jump where all arguments
  202       -- are captured variables. See Note [Idempotency]
  203       -- But _do_ float join jumps with interesting arguments.
  204       -- See Note [Jumps can be interesting]
  205       | (Var f, args) <- collectArgs e
  206       , isJoinId f
  207       , all isCapturedVarArg args
  208       = return e
  209 
  210       -- Do not touch a boring expression (see Note [Interesting expression])
  211       | not is_interesting
  212       = return e
  213 
  214       -- Cannot float out if local join points are used, as
  215       -- we cannot abstract over them
  216       | captures_join_points
  217       = return e
  218 
  219       -- We have something to float out!
  220       | otherwise
  221       = do { -- Assemble the RHS of the exit join point
  222              let rhs   = mkLams abs_vars e
  223                  avoid = in_scope `extendInScopeSetList` captured
  224              -- Remember this binding under a suitable name
  225            ; v <- addExit avoid (length abs_vars) rhs
  226              -- And jump to it from here
  227            ; return $ mkVarApps (Var v) abs_vars }
  228 
  229       where
  230         -- Used to detect exit expressions that are already proper exit jumps
  231         isCapturedVarArg (Var v) = v `elem` captured
  232         isCapturedVarArg _ = False
  233 
  234         -- An interesting exit expression has free, non-imported
  235         -- variables from outside the recursive group
  236         -- See Note [Interesting expression]
  237         is_interesting = anyVarSet isLocalId $
  238                          fvs `minusVarSet` mkVarSet captured
  239 
  240         -- The arguments of this exit join point
  241         -- See Note [Picking arguments to abstract over]
  242         abs_vars = snd $ foldr pick (fvs, []) captured
  243           where
  244             pick v (fvs', acc) | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
  245                                | otherwise           = (fvs',               acc)
  246 
  247         -- We are going to abstract over these variables, so we must
  248         -- zap any IdInfo they have; see #15005
  249         -- cf. GHC.Core.Opt.SetLevels.abstractVars
  250         zap v | isId v = setIdInfo v vanillaIdInfo
  251               | otherwise = v
  252 
  253         -- We cannot abstract over join points
  254         captures_join_points = any isJoinId abs_vars
  255 
  256 
  257 -- Picks a new unique, which is disjoint from
  258 --  * the free variables of the whole joinrec
  259 --  * any bound variables (captured)
  260 --  * any exit join points created so far.
  261 mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
  262 mkExitJoinId in_scope ty join_arity = do
  263     fs <- get
  264     let avoid = in_scope `extendInScopeSetList` (map fst fs)
  265                          `extendInScopeSet` exit_id_tmpl -- just cosmetics
  266     return (uniqAway avoid exit_id_tmpl)
  267   where
  268     exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique Many ty
  269                     `asJoinId` join_arity
  270 
  271 addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
  272 addExit in_scope join_arity rhs = do
  273     -- Pick a suitable name
  274     let ty = exprType rhs
  275     v <- mkExitJoinId in_scope ty join_arity
  276     fs <- get
  277     put ((v,rhs):fs)
  278     return v
  279 
  280 {-
  281 Note [Interesting expression]
  282 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  283 We do not want this to happen:
  284 
  285   joinrec go 0     x y = x
  286           go (n-1) x y = jump go (n-1) (x+y)
  287   in …
  288 ==>
  289   join exit x = x
  290   joinrec go 0     x y = jump exit x
  291           go (n-1) x y = jump go (n-1) (x+y)
  292   in …
  293 
  294 because the floated exit path (`x`) is simply a parameter of `go`; there are
  295 no useful interactions exposed this way.
  296 
  297 Neither do we want this to happen
  298 
  299   joinrec go 0     x y = x+x
  300           go (n-1) x y = jump go (n-1) (x+y)
  301   in …
  302 ==>
  303   join exit x = x+x
  304   joinrec go 0     x y = jump exit x
  305           go (n-1) x y = jump go (n-1) (x+y)
  306   in …
  307 
  308 where the floated expression `x+x` is a bit more complicated, but still not
  309 interesting.
  310 
  311 Expressions are interesting when they move an occurrence of a variable outside
  312 the recursive `go` that can benefit from being obviously called once, for example:
  313  * a local thunk that can then be inlined (see example in note [Exitification])
  314  * the parameter of a function, where the demand analyzer can then
  315    see that it is called at most once, and hence improve the function’s
  316    strictness signature
  317 
  318 So we only hoist an exit expression out if it mentions at least one free,
  319 non-imported variable.
  320 
  321 Note [Jumps can be interesting]
  322 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  323 A jump to a join point can be interesting, if its arguments contain free
  324 non-imported variables (z in the following example):
  325 
  326   joinrec go 0     x y = jump j (x+z)
  327           go (n-1) x y = jump go (n-1) (x+y)
  328   in …
  329 ==>
  330   join exit x = jump j (x+z)
  331   joinrec go 0     x y = jump exit x
  332           go (n-1) x y = jump go (n-1) (x+y)
  333 
  334 
  335 The join point itself can be interesting, even if none of its
  336 arguments have free variables free in the joinrec.  For example
  337 
  338   join j p = case p of (x,y) -> x+y
  339   joinrec go 0     x y = jump j (x,y)
  340           go (n-1) x y = jump go (n-1) (x+y) y
  341   in …
  342 
  343 Here, `j` would not be inlined because we do not inline something that looks
  344 like an exit join point (see Note [Do not inline exit join points]). But
  345 if we exitify the 'jump j (x,y)' we get
  346 
  347   join j p = case p of (x,y) -> x+y
  348   join exit x y = jump j (x,y)
  349   joinrec go 0     x y = jump exit x y
  350           go (n-1) x y = jump go (n-1) (x+y) y
  351   in …
  352 
  353 and now 'j' can inline, and we get rid of the pair. Here's another
  354 example (assume `g` to be an imported function that, on its own,
  355 does not make this interesting):
  356 
  357   join j y = map f y
  358   joinrec go 0     x y = jump j (map g x)
  359           go (n-1) x y = jump go (n-1) (x+y)
  360   in …
  361 
  362 Again, `j` would not be inlined because we do not inline something that looks
  363 like an exit join point (see Note [Do not inline exit join points]).
  364 
  365 But after exitification we have
  366 
  367   join j y = map f y
  368   join exit x = jump j (map g x)
  369   joinrec go 0     x y = jump exit x
  370           go (n-1) x y = jump go (n-1) (x+y)
  371   in …
  372 
  373 and now we can inline `j` and this will allow `map/map` to fire.
  374 
  375 
  376 Note [Idempotency]
  377 ~~~~~~~~~~~~~~~~~~
  378 
  379 We do not want this to happen, where we replace the floated expression with
  380 essentially the same expression:
  381 
  382   join exit x = t (x*x)
  383   joinrec go 0     x y = jump exit x
  384           go (n-1) x y = jump go (n-1) (x+y)
  385   in …
  386 ==>
  387   join exit x = t (x*x)
  388   join exit' x = jump exit x
  389   joinrec go 0     x y = jump exit' x
  390           go (n-1) x y = jump go (n-1) (x+y)
  391   in …
  392 
  393 So when the RHS is a join jump, and all of its arguments are captured variables,
  394 then we leave it in place.
  395 
  396 Note that `jump exit x` in this example looks interesting, as `exit` is a free
  397 variable. Therefore, idempotency does not simply follow from floating only
  398 interesting expressions.
  399 
  400 Note [Calculating free variables]
  401 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  402 We have two options where to annotate the tree with free variables:
  403 
  404  A) The whole tree.
  405  B) Each individual joinrec as we come across it.
  406 
  407 Downside of A: We pay the price on the whole module, even outside any joinrecs.
  408 Downside of B: We pay the price per joinrec, possibly multiple times when
  409 joinrecs are nested.
  410 
  411 Further downside of A: If the exitify function returns annotated expressions,
  412 it would have to ensure that the annotations are correct.
  413 
  414 We therefore choose B, and calculate the free variables in `exitifyRec`.
  415 
  416 
  417 Note [Do not inline exit join points]
  418 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  419 When we have
  420 
  421   let t = foo bar
  422   join exit x = t (x*x)
  423   joinrec go 0     x y = jump exit x
  424           go (n-1) x y = jump go (n-1) (x+y)
  425   in …
  426 
  427 we do not want the simplifier to simply inline `exit` back in (which it happily
  428 would).
  429 
  430 To prevent this, we need to recognize exit join points, and then disable
  431 inlining.
  432 
  433 Exit join points are join points with an occurrence in a recursive group.
  434 They can be recognized (after the occurrence analyzer ran!) using
  435 `isExitJoinId`.
  436 This function detects joinpoints with `occ_in_lam (idOccInfo id) == True`,
  437 because the lambdas of a non-recursive join point are not considered for
  438 `occ_in_lam`.  For example, in the following code, `j1` is /not/ marked
  439 occ_in_lam, because `j2` is called only once.
  440 
  441   join j1 x = x+1
  442   join j2 y = jump j1 (y+2)
  443 
  444 To prevent inlining, we check for isExitJoinId
  445 * In `preInlineUnconditionally` directly.
  446 * In `simplLetUnfolding` we simply give exit join points no unfolding, which
  447   prevents inlining in `postInlineUnconditionally` and call sites.
  448 
  449 Note [Placement of the exitification pass]
  450 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  451 I (Joachim) experimented with multiple positions for the Exitification pass in
  452 the Core2Core pipeline:
  453 
  454  A) Before the `simpl_phases`
  455  B) Between the `simpl_phases` and the "main" simplifier pass
  456  C) After demand_analyser
  457  D) Before the final simplification phase
  458 
  459 Here is the table (this is without inlining join exit points in the final
  460 simplifier run):
  461 
  462         Program |                       Allocs                      |                      Instrs
  463                 | ABCD.log     A.log     B.log     C.log     D.log  | ABCD.log     A.log     B.log     C.log     D.log
  464 ----------------|---------------------------------------------------|-------------------------------------------------
  465  fannkuch-redux |   -99.9%     +0.0%    -99.9%    -99.9%    -99.9%  |    -3.9%     +0.5%     -3.0%     -3.9%     -3.9%
  466           fasta |    -0.0%     +0.0%     +0.0%     -0.0%     -0.0%  |    -8.5%     +0.0%     +0.0%     -0.0%     -8.5%
  467             fem |     0.0%      0.0%      0.0%      0.0%     +0.0%  |    -2.2%     -0.1%     -0.1%     -2.1%     -2.1%
  468            fish |     0.0%      0.0%      0.0%      0.0%     +0.0%  |    -3.1%     +0.0%     -1.1%     -1.1%     -0.0%
  469    k-nucleotide |   -91.3%    -91.0%    -91.0%    -91.3%    -91.3%  |    -6.3%    +11.4%    +11.4%     -6.3%     -6.2%
  470             scs |    -0.0%     -0.0%     -0.0%     -0.0%     -0.0%  |    -3.4%     -3.0%     -3.1%     -3.3%     -3.3%
  471          simple |    -6.0%      0.0%     -6.0%     -6.0%     +0.0%  |    -3.4%     +0.0%     -5.2%     -3.4%     -0.1%
  472   spectral-norm |    -0.0%      0.0%      0.0%     -0.0%     +0.0%  |    -2.7%     +0.0%     -2.7%     -5.4%     -5.4%
  473 ----------------|---------------------------------------------------|-------------------------------------------------
  474             Min |   -95.0%    -91.0%    -95.0%    -95.0%    -95.0%  |    -8.5%     -3.0%     -5.2%     -6.3%     -8.5%
  475             Max |    +0.2%     +0.2%     +0.2%     +0.2%     +1.5%  |    +0.4%    +11.4%    +11.4%     +0.4%     +1.5%
  476  Geometric Mean |    -4.7%     -2.1%     -4.7%     -4.7%     -4.6%  |    -0.4%     +0.1%     -0.1%     -0.3%     -0.2%
  477 
  478 Position A is disqualified, as it does not get rid of the allocations in
  479 fannkuch-redux.
  480 Positions A and B are disqualified because they increase instructions in k-nucleotide.
  481 Positions C and D have their advantages: C decreases allocations in simple, but D decreases instructions in fasta.
  482 
  483 Assuming we have a budget of _one_ run of Exitification, then C wins (but we
  484 could get more from running it multiple times, as seen in fish).
  485 
  486 Note [Picking arguments to abstract over]
  487 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  488 
  489 When we create an exit join point, we need to abstract over those of its
  490 free variables that would be out-of-scope at the destination of the exit join
  491 point. So we go through the list `captured` and pick those that are actually
  492 free variables of the join point.
  493 
  494 We do not just `filter (`elemVarSet` fvs) captured`, as there might be
  495 shadowing, and `captured` may contain multiple variables with the same Unique. In
  496 these cases we want to abstract only over the last occurrence, hence the `foldr`
  497 (with emphasis on the `r`). This is #15110.
  498 
  499 -}