never executed always true always false
    1 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
    2 
    3 module Conjure.Language.Instantiate
    4     ( instantiateExpression
    5     , instantiateDomain
    6     , trySimplify
    7     , entailed
    8     ) where
    9 
   10 -- conjure
   11 import Conjure.Prelude
   12 import Conjure.Bug
   13 import Conjure.UserError
   14 import Conjure.Language.Definition
   15 import Conjure.Language.Expression.Op
   16 import Conjure.Language.Domain
   17 import Conjure.Language.Constant
   18 import Conjure.Language.Type
   19 import Conjure.Language.TypeOf
   20 import Conjure.Language.Pretty
   21 import Conjure.Language.EvaluateOp ( evaluateOp )
   22 import Conjure.Process.Enumerate ( EnumerateDomain, enumerateDomain, enumerateInConstant )
   23 
   24 
   -- | Try to simplify an expression recursively.
   --
   -- When the whole expression instantiates (under the given bindings) to a
   -- constant with no 'ConstantUndefined' anywhere inside it, the expression
   -- is replaced by that constant. Otherwise the same is attempted on each
   -- immediate child of the expression.
trySimplify ::
    MonadUserError m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    [(Name, Expression)] -> Expression -> m Expression
trySimplify ctxt x = do
    attempt <- runMaybeT (instantiateExpression ctxt x)
    case attempt of
        Just c | not (any isUndefConstant (universe c)) -> return (Constant c)
        _ -> descendM (trySimplify ctxt) x
  where
    isUndefConstant ConstantUndefined{} = True
    isUndefConstant _ = False
   39 
   40 
   -- | Evaluate an expression to a normalised constant, resolving names from
   -- the given bindings. If the result is an empty collection that is not
   -- already a 'TypedConstant', it is wrapped in one using the type of the
   -- input expression.
instantiateExpression ::
    MonadFailDoc m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    [(Name, Expression)] -> Expression -> m Constant
instantiateExpression ctxt x = do
    result <- normaliseConstant <$> evalStateT (instantiateE x) ctxt
    case result of
        TypedConstant{} -> return result
        _ | emptyCollection result -> TypedConstant result <$> typeOf x
          | otherwise -> return result
   55 
   56 
   -- | Instantiate every expression inside a domain using the given bindings,
   -- then normalise the constants in the resulting domain.
instantiateDomain ::
    MonadFailDoc m =>
    EnumerateDomain m =>
    NameGen m =>
    Pretty r =>
    Default r =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    [(Name, Expression)] -> Domain r Expression -> m (Domain r Constant)
instantiateDomain ctxt x = do
    instantiated <- evalStateT (instantiateD x) ctxt
    return (normaliseDomain normaliseConstant instantiated)
   66 
   67 
   -- | Writer output used while enumerating a comprehension: it holds 'True'
   -- once some generator domain was found to contain an undefined value.
   -- The instances come from 'Any', so '<>' is logical "or" and 'mempty'
   -- is 'False'.
newtype HasUndef = HasUndef Any deriving (Semigroup, Monoid)
   70 
   -- | Evaluate an expression to a 'Constant', resolving references through
   -- the binding list held in the state. 'lookup' returns the first match,
   -- so bindings prepended to the list shadow older ones.
   --
   -- Fails via 'failDoc' when a referenced name has no usable binding, when
   -- a local constraint does not evaluate to @true@, or when the expression
   -- form is not supported.
instantiateE ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    Expression -> m Constant
instantiateE (Comprehension body gensOrConds) = do
    let
        -- Evaluate the body once for every combination of generator values
        -- that passes the conditions. The writer output records whether some
        -- generator domain contained undefined values.
        loop ::
            MonadFailDoc m =>
            MonadState [(Name, Expression)] m =>
            EnumerateDomain m =>
            NameGen m =>
            [GeneratorOrCondition] -> WriterT HasUndef m [Constant]
        loop [] = return <$> instantiateE body
        loop (Generator (GenDomainNoRepr pat domain) : rest) = do
            DomainInConstant domainConstant <- instantiateE (Domain domain)
            let undefinedsInsideTheDomain =
                    [ und
                    | und@ConstantUndefined{} <- universeBi domainConstant
                    ]
            if null undefinedsInsideTheDomain
                then do
                    enumeration <- enumerateDomain domainConstant
                    -- each value is bound inside 'scope', so its binding is
                    -- rolled back before the next value is tried
                    concatMapM
                        (\ val -> scope $ do
                            valid <- bind pat val
                            if valid
                                then loop rest
                                else return [] )
                        enumeration
                else do
                    -- the domain cannot be enumerated: record it, yield nothing
                    tell (HasUndef (Any True))
                    return []
        loop (Generator (GenDomainHasRepr pat domain) : rest) =
            loop (Generator (GenDomainNoRepr (Single pat) (forgetRepr domain)) : rest)
        loop (Generator (GenInExpr pat expr) : rest) = do
            exprConstant <- instantiateE expr
            enumeration <- enumerateInConstant exprConstant
            concatMapM
                (\ val -> scope $ do
                    valid <- bind pat val
                    if valid
                        then loop rest
                        else return [] )
                enumeration
        loop (Condition expr : rest) = do
            -- anything other than exactly 'ConstantBool True' drops this
            -- combination of generator values
            constant <- instantiateE expr
            if constant == ConstantBool True
                then loop rest
                else return []
        loop (ComprehensionLetting pat expr : rest) = do
            constant <- instantiateE expr
            valid <- bind pat constant
            unless valid (bug "ComprehensionLetting.bind expected to be valid")
            loop rest

    -- Run the whole loop inside 'scope' so that bindings introduced by this
    -- comprehension (in particular top-level lettings, which 'bind' prepends
    -- to the state outside any per-value 'scope') are discarded afterwards.
    -- Otherwise they would shadow outer bindings of the same name for sibling
    -- expressions evaluated later in the same state.
    (constants, HasUndef (Any undefinedsInsideGeneratorDomains)) <-
        scope (runWriterT (loop gensOrConds))
    if undefinedsInsideGeneratorDomains
        then do
            ty <- typeOf (Comprehension body gensOrConds)
            return $ ConstantUndefined
                "Comprehension contains undefined values inside generator domains."
                ty
        else
            return $ fromList constants

-- record and variant field names evaluate to a 'ConstantField'
instantiateE (Reference name (Just (RecordField _ ty))) = return $ ConstantField name ty
instantiateE (Reference name (Just (VariantField _ ty))) = return $ ConstantField name ty
instantiateE (Reference name refto) = do
    ctxt <- gets id
    case name `lookup` ctxt of
        Just x -> instantiateE x
        Nothing ->
            case refto of
                Just (Alias x) ->
                    -- we do not have this name in context, but we have it stored in the Reference itself
                    -- reuse that
                    instantiateE x
                _ ->
                    failDoc $ vcat
                    $ ("No value for:" <+> pretty name)
                    : "Bindings in context:"
                    : prettyContext ctxt

instantiateE (Constant c) = return c
instantiateE (AbstractLiteral lit) = instantiateAbsLit lit
instantiateE (Typed x ty) = TypedConstant <$> instantiateE x <*> pure ty
instantiateE (Op op) = instantiateOp op

-- "Domain () Expression"s inside expressions are handled specially
instantiateE (Domain (DomainReference _ (Just d))) = instantiateE (Domain d)
instantiateE (Domain (DomainReference name Nothing)) = do
    ctxt <- gets id
    case name `lookup` ctxt of
        Just (Domain d) -> instantiateE (Domain d)
        _ -> failDoc $ vcat
            $ ("No value for:" <+> pretty name)
            : "Bindings in context:"
            : prettyContext ctxt
instantiateE (Domain domain) = DomainInConstant <$> instantiateD domain

-- every such-that constraint on auxiliary variables must evaluate to true
instantiateE (WithLocals b (AuxiliaryVars locals)) = do
    forM_ locals $ \ local -> case local of
        SuchThat xs -> forM_ xs $ \ x -> do
            constant <- instantiateE x
            case constant of
                ConstantBool True -> return ()
                _                 -> failDoc $ "local:" <+> pretty constant
        _ -> failDoc $ "local:" <+> pretty local
    instantiateE b

-- every definedness constraint must evaluate to true
instantiateE (WithLocals b (DefinednessConstraints locals)) = do
    forM_ locals $ \ x -> do
        constant <- instantiateE x
        case constant of
            ConstantBool True -> return ()
            _                 -> failDoc $ "local:" <+> pretty constant
    instantiateE b

instantiateE x = failDoc $ "instantiateE:" <+> pretty (show x)
  194 
  195 
  -- | Evaluate an operator application: instantiate every argument, normalise
  -- the resulting constants and hand the operator to 'evaluateOp'.
instantiateOp ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    Op Expression -> m Constant
instantiateOp opx = do
    args <- mapM instantiateE opx
    evaluateOp (normaliseConstant <$> args)
  204 
  205 
  -- | Instantiate every element of an abstract literal.
  --
  -- Function literals are deduplicated. If some value is still mapped to more
  -- than one image afterwards, the result is a 'ConstantUndefined' of the
  -- literal's type.
instantiateAbsLit ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    AbstractLiteral Expression -> m Constant
instantiateAbsLit x = do
    lit <- mapM instantiateE x
    case lit of
        AbsLitFunction mappings -> do
            let distinctMappings = sortNub mappings
                distinctKeys     = sortNub (map fst distinctMappings)
            if length distinctKeys /= length distinctMappings
                then ConstantUndefined "Multiple mappings for the same value." <$> typeOf lit
                else return (ConstantAbstract (AbsLitFunction distinctMappings))
        _ -> return (ConstantAbstract lit)
  225 
  226 
  -- | Instantiate every expression inside a domain, producing a domain over
  -- constants. Enum domain names, domain references and enum members are
  -- looked up in the binding list held in the state.
instantiateD ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    Pretty r =>
    Default r =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    Domain r Expression -> m (Domain r Constant)
instantiateD (DomainAny t ty) = return (DomainAny t ty)
instantiateD DomainBool = return DomainBool
instantiateD (DomainIntE maybe_tag x) = do
    -- The expression may evaluate to a single integer, a matrix or a set.
    -- Each value becomes a single-value range; any other result gives no ranges.
    evaluated <- instantiateE x
    let singles
            | ConstantInt{} <- evaluated                   = [evaluated]
            | Just (_, xs) <- viewConstantMatrix evaluated = xs
            | Just xs <- viewConstantSet evaluated         = xs
            | otherwise                                    = []
    return $ DomainInt maybe_tag (map RangeSingle singles)
instantiateD (DomainInt t ranges) = DomainInt t <$> mapM instantiateR ranges
instantiateD (DomainEnum nm Nothing _) = do
    -- no ranges given: instantiate the domain bound to this name instead
    stateBindings <- gets id
    case lookup nm stateBindings of
        Nothing -> failDoc $ ("DomainEnum not found in state, Nothing:" <+> pretty nm) <++> vcat (map pretty stateBindings)
        Just (Domain dom) -> instantiateD (defRepr dom)
        Just _ -> failDoc $ ("DomainEnum not found in state, Just:" <+> pretty nm) <++> vcat (map pretty stateBindings)
instantiateD (DomainEnum nm rs0 _) = do
    -- instantiate the expressions inside the ranges (bottom-up), then convert
    -- the resulting constant expressions back to 'Constant's
    instantiatedRanges <- transformBiM (fmap Constant . instantiateE) (rs0 :: Maybe [Range Expression])
    let fromConstExpr e = either bug id (e2c e)
        rs = (fmap . fmap . fmap) fromConstExpr instantiatedRanges :: Maybe [Range Constant]
    stateBindings <- gets id
    -- every name mentioned in the ranges must be bound to an integer constant
    memberValues <- forM (universeBi rs :: [Name]) $ \ n -> case lookup n stateBindings of
        Just (Constant (ConstantInt _ i)) -> return (n, i)
        Nothing -> failDoc $ "No value for member of enum domain:" <+> pretty n
        Just c  -> failDoc $ vcat [ "Incompatible value for member of enum domain:" <+> pretty nm
                               , "    Looking up for member:" <+> pretty n
                               , "    Expected an integer, but got:" <+> pretty c
                               ]
    return (DomainEnum nm rs (Just memberValues))
instantiateD (DomainUnnamed nm s) = DomainUnnamed nm <$> instantiateE s
instantiateD (DomainTuple inners) = DomainTuple <$> mapM instantiateD inners
instantiateD (DomainRecord inners) =
    DomainRecord <$> forM inners (\ (n, d) -> (,) n <$> instantiateD d)
instantiateD (DomainVariant inners) =
    DomainVariant <$> forM inners (\ (n, d) -> (,) n <$> instantiateD d)
instantiateD (DomainMatrix index inner) =
    DomainMatrix <$> instantiateD index <*> instantiateD inner
instantiateD (DomainSet r attrs inner) =
    DomainSet r <$> instantiateSetAttr attrs <*> instantiateD inner
instantiateD (DomainMSet r attrs inner) =
    DomainMSet r <$> instantiateMSetAttr attrs <*> instantiateD inner
instantiateD (DomainFunction r attrs innerFr innerTo) =
    DomainFunction r <$> instantiateFunctionAttr attrs <*> instantiateD innerFr <*> instantiateD innerTo
instantiateD (DomainSequence r attrs inner) =
    DomainSequence r <$> instantiateSequenceAttr attrs <*> instantiateD inner
instantiateD (DomainRelation r attrs inners) =
    DomainRelation r <$> instantiateRelationAttr attrs <*> mapM instantiateD inners
instantiateD (DomainPartition r attrs inner) =
    DomainPartition r <$> instantiatePartitionAttr attrs <*> instantiateD inner
instantiateD (DomainPermutation r attrs inner) =
    DomainPermutation r <$> instantiatePermutationAttr attrs <*> instantiateD inner
instantiateD (DomainOp nm ds) = DomainOp nm <$> mapM instantiateD ds
instantiateD (DomainReference _ (Just d)) = instantiateD d
instantiateD (DomainReference name Nothing) = do
    ctxt <- gets id
    case name `lookup` ctxt of
        Just (Domain d) -> instantiateD (defRepr d)
        _ -> failDoc $ vcat
            $ ("No value for:" <+> pretty name)
            : "Bindings in context:"
            : prettyContext ctxt
instantiateD DomainMetaVar{} = bug "instantiateD DomainMetaVar"
  292 
  293 
  -- | Instantiate the size attribute of a set domain.
instantiateSetAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    SetAttr Expression -> m (SetAttr Constant)
instantiateSetAttr (SetAttr sizeAttr) = fmap SetAttr (instantiateSizeAttr sizeAttr)
  302 
  303 
  -- | Instantiate the bound expression(s) of a size attribute; the lower
  -- bound is evaluated before the upper one.
instantiateSizeAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    SizeAttr Expression -> m (SizeAttr Constant)
instantiateSizeAttr attr =
    case attr of
        SizeAttr_None -> return SizeAttr_None
        SizeAttr_Size n -> fmap SizeAttr_Size (instantiateE n)
        SizeAttr_MinSize lo -> fmap SizeAttr_MinSize (instantiateE lo)
        SizeAttr_MaxSize hi -> fmap SizeAttr_MaxSize (instantiateE hi)
        SizeAttr_MinMaxSize lo hi -> do
            lo' <- instantiateE lo
            hi' <- instantiateE hi
            return (SizeAttr_MinMaxSize lo' hi')
  316 
  317 
  -- | Instantiate the size and occurrence attributes of a multi-set domain.
instantiateMSetAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    MSetAttr Expression -> m (MSetAttr Constant)
instantiateMSetAttr (MSetAttr sizeAttr occurAttr) = do
    sizeAttr' <- instantiateSizeAttr sizeAttr
    occurAttr' <- instantiateOccurAttr occurAttr
    return (MSetAttr sizeAttr' occurAttr')
  326 
  327 
  -- | Instantiate the bound expression(s) of an occurrence attribute; the
  -- lower bound is evaluated before the upper one.
instantiateOccurAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    OccurAttr Expression -> m (OccurAttr Constant)
instantiateOccurAttr attr =
    case attr of
        OccurAttr_None -> return OccurAttr_None
        OccurAttr_MinOccur lo -> fmap OccurAttr_MinOccur (instantiateE lo)
        OccurAttr_MaxOccur hi -> fmap OccurAttr_MaxOccur (instantiateE hi)
        OccurAttr_MinMaxOccur lo hi -> do
            lo' <- instantiateE lo
            hi' <- instantiateE hi
            return (OccurAttr_MinMaxOccur lo' hi')
  339 
  340 
  -- | Instantiate the size attribute of a function domain. The other two
  -- attributes are not parameterised by the expression type and are copied
  -- unchanged.
instantiateFunctionAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    FunctionAttr Expression -> m (FunctionAttr Constant)
instantiateFunctionAttr (FunctionAttr sizeAttr pAttr jAttr) = do
    sizeAttr' <- instantiateSizeAttr sizeAttr
    return (FunctionAttr sizeAttr' pAttr jAttr)
  352 
  353 
  -- | Instantiate the size attribute of a sequence domain. The second
  -- attribute is not parameterised by the expression type and is copied
  -- unchanged. (No 'MonadUserError' constraint is needed: only
  -- 'instantiateSizeAttr' is called, matching the other attribute helpers.)
instantiateSequenceAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    SequenceAttr Expression -> m (SequenceAttr Constant)
instantiateSequenceAttr (SequenceAttr s j) =
    SequenceAttr <$> instantiateSizeAttr s
                 <*> pure j
  365 
  366 
  -- | Instantiate the size attribute of a relation domain. The second
  -- attribute is not parameterised by the expression type and is copied
  -- unchanged. (No 'MonadUserError' constraint is needed: only
  -- 'instantiateSizeAttr' is called, matching the other attribute helpers.)
instantiateRelationAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    RelationAttr Expression -> m (RelationAttr Constant)
instantiateRelationAttr (RelationAttr s b) = RelationAttr <$> instantiateSizeAttr s <*> pure b
  376 
  377 
  -- | Instantiate both size attributes of a partition domain. The third
  -- attribute is not parameterised by the expression type and is copied
  -- unchanged. (No 'MonadUserError' constraint is needed: only
  -- 'instantiateSizeAttr' is called, matching the other attribute helpers.)
instantiatePartitionAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    PartitionAttr Expression -> m (PartitionAttr Constant)
instantiatePartitionAttr (PartitionAttr a b r) =
    PartitionAttr <$> instantiateSizeAttr a
                  <*> instantiateSizeAttr b
                  <*> pure r
  390 
  391 
  -- | Instantiate the size attribute of a permutation domain.
instantiatePermutationAttr ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    PermutationAttr Expression -> m (PermutationAttr Constant)
instantiatePermutationAttr (PermutationAttr sizeAttr) = do
    sizeAttr' <- instantiateSizeAttr sizeAttr
    return (PermutationAttr sizeAttr')
  400 
  401 
  -- | Instantiate the endpoint expression(s) of a range; the lower endpoint
  -- is evaluated before the upper one.
instantiateR ::
    MonadFailDoc m =>
    MonadState [(Name, Expression)] m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    Range Expression -> m (Range Constant)
instantiateR range =
    case range of
        RangeOpen -> return RangeOpen
        RangeSingle v -> fmap RangeSingle (instantiateE v)
        RangeLowerBounded lo -> fmap RangeLowerBounded (instantiateE lo)
        RangeUpperBounded hi -> fmap RangeUpperBounded (instantiateE hi)
        RangeBounded lo hi -> do
            lo' <- instantiateE lo
            hi' <- instantiateE hi
            return (RangeBounded lo' hi')
  414 
  415 
  -- | Bind the names of a pattern to the corresponding parts of a constant by
  -- prepending @(name, Constant value)@ pairs to the binding list.
  --
  -- Returns 'False' (meaning "skip this value") when a set pattern and a set
  -- constant have different sizes. Any other mismatch between pattern and
  -- value, including a tuple/matrix length mismatch, is reported with 'bug'.
bind :: (Functor m, MonadState [(Name, Expression)] m)
    => AbstractPattern
    -> Constant
    -> m Bool -- False means skip
bind pat val =
    case pat of
        Single nm -> do
            modify ((nm, Constant val) :)
            return True
        AbsPatTuple pats
            | Just vals <- viewConstantTuple val
            , sameLength pats vals
            -> bindEach pats vals
        AbsPatMatrix pats
            | Just (_, vals) <- viewConstantMatrix val
            , sameLength pats vals
            -> bindEach pats vals
        AbsPatSet pats
            | Just vals <- viewConstantSet val
            -> if sameLength pats vals
                then bindEach pats vals
                else return False
        _ -> bug $ "Instantiate.bind:" <++> vcat ["pat:" <+> pretty pat, "val:" <+> pretty val]
  where
    sameLength xs ys = length xs == length ys
    -- binds every component (no short-circuit) and succeeds only if all do
    bindEach ps vs = and <$> zipWithM bind ps vs
  429 
  430 
  -- | Check whether the given expression simplifies (with no bindings) to the
  -- constant @true@. 'False' means "not entailed", as opposed to "known to be
  -- false".
entailed ::
    MonadUserError m =>
    EnumerateDomain m =>
    NameGen m =>
    (?typeCheckerMode :: TypeCheckerMode) =>
    Expression -> m Bool
entailed x = simplifiedToTrue <$> trySimplify [] x
  where
    simplifiedToTrue (Constant (ConstantBool True)) = True
    simplifiedToTrue _ = False
  446