summaryrefslogtreecommitdiff
path: root/src/Language/PureScript/TypeChecker/Entailment.hs
blob: be5e2d7ba0c6baf503621df9cdb66ff5d7e5e634 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
-----------------------------------------------------------------------------
--
-- Module      :  Language.PureScript.TypeChecker.Entailment
-- Copyright   :  (c) Phil Freeman 2013
-- License     :  MIT
--
-- Maintainer  :  Phil Freeman <paf31@cantab.net>
-- Stability   :  experimental
-- Portability :
--
-- |
-- Type class entailment
--
-----------------------------------------------------------------------------

module Language.PureScript.TypeChecker.Entailment (
    entails
) where

import Data.Function (on)
import Data.List
import Data.Maybe (maybeToList)
import Data.Foldable (foldMap)
import qualified Data.Map as M

import Control.Applicative
import Control.Arrow (Arrow(..))
import Control.Monad.Except

import Language.PureScript.AST
import Language.PureScript.Errors
import Language.PureScript.Environment
import Language.PureScript.Names
import Language.PureScript.TypeChecker.Monad
import Language.PureScript.TypeChecker.Synonyms
import Language.PureScript.TypeChecker.Unify
import Language.PureScript.TypeClassDictionaries
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C

-- |
-- Check that the current set of type class dictionaries entail the specified type class goal, and, if so,
-- return a type class dictionary reference.
--
-- Only dictionaries whose names are unqualified or qualified with the given module name are considered,
-- and they are de-duplicated by their canonical name. The final 'Bool' argument controls whether the goal
-- may also be satisfied by following the superclass implications of other classes.
--
-- Fails with 'OverlappingInstances' if two candidate dictionaries overlap, and with 'NoInstanceFound'
-- if no candidate dictionary can be built.
--
entails :: Environment -> ModuleName -> [TypeClassDictionaryInScope] -> Constraint -> Bool -> Check Expr
entails env moduleName context = solve (sortedNubBy canonicalizeDictionary (filter filterModule context))
  where
    -- De-duplicate values by a key. The result is ordered by key, and for duplicate keys the
    -- last value in the input wins (the semantics of 'M.fromList').
    sortedNubBy :: (Ord k) => (v -> k) -> [v] -> [v]
    sortedNubBy f vs = M.elems (M.fromList (map (f &&& id) vs))

    -- Keep only the dictionaries whose names are unqualified or qualified with the current module
    filterModule :: TypeClassDictionaryInScope -> Bool
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified (Just mn) _ }) | mn == moduleName = True
    filterModule (TypeClassDictionaryInScope { tcdName = Qualified Nothing _ }) = True
    filterModule _ = False

    -- Solve a single constraint against the given dictionaries, rejecting overlapping solutions
    solve :: [TypeClassDictionaryInScope] -> Constraint -> Bool -> Check Expr
    solve context' (className, tys) trySuperclasses =
      checkOverlaps $ go trySuperclasses className tys
      where
      -- Enumerate every way to build a dictionary for className' at the types tys': either directly
      -- from a dictionary in context', or (when trySuperclasses' is True) by projecting a superclass
      -- out of a dictionary for some subclass. Recursive searches always allow superclasses.
      go trySuperclasses' className' tys' =
        -- Look for regular type instances
        [ mkDictionary (canonicalizeDictionary tcd) args
        | tcd <- context'
        -- Make sure the type class name matches the one we are trying to satisfy
        , className' == tcdClassName tcd
        -- Make sure the type unifies with the type in the type instance definition
        , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' (tcdInstanceTypes tcd)
        -- Solve any necessary subgoals
        , args <- solveSubgoals subst (tcdDependencies tcd) ] ++

        -- Look for implementations via superclasses
        [ SubclassDictionaryValue suDict superclass index
        | trySuperclasses'
        , (subclassName, (args, _, implies)) <- M.toList (typeClasses env)
        -- Try each superclass
        , (index, (superclass, suTyArgs)) <- zip [0..] implies
        -- Make sure the type class name matches the superclass name
        , className' == superclass
        -- Make sure the types unify with the types in the superclass implication
        , subst <- maybeToList . (>>= verifySubstitution) . fmap concat $ zipWithM (typeHeadsAreEqual moduleName env) tys' suTyArgs
        -- Finally, satisfy the subclass constraint
        , args' <- maybeToList $ mapM ((`lookup` subst) . fst) args
        , suDict <- go True subclassName args' ]

      -- Create dictionaries for subgoals which still need to be solved by calling go recursively
      -- E.g. the goal (Show a, Show b) => Show (Either a b) can be satisfied if the current type
      -- unifies with Either a b, and we can satisfy the subgoals Show a and Show b recursively.
      -- A Nothing dependency list yields a single Nothing (no subgoals at all).
      solveSubgoals :: [(String, Type)] -> Maybe [Constraint] -> [Maybe [DictionaryValue]]
      solveSubgoals _ Nothing = return Nothing
      solveSubgoals subst (Just subgoals) = do
        dict <- mapM (uncurry (go True) . second (map (replaceAllTypeVars subst))) subgoals
        return $ Just dict
      -- Make a dictionary from subgoal dictionaries by applying the correct function
      mkDictionary :: Qualified Ident -> Maybe [DictionaryValue] -> DictionaryValue
      mkDictionary fnName Nothing = LocalDictionaryValue fnName
      mkDictionary fnName (Just []) = GlobalDictionaryValue fnName
      mkDictionary fnName (Just dicts) = DependentDictionaryValue fnName dicts
      -- Turn a DictionaryValue into a Expr
      dictionaryValueToValue :: DictionaryValue -> Expr
      dictionaryValueToValue (LocalDictionaryValue fnName) = Var fnName
      dictionaryValueToValue (GlobalDictionaryValue fnName) = Var fnName
      dictionaryValueToValue (DependentDictionaryValue fnName dicts) = foldl App (Var fnName) (map dictionaryValueToValue dicts)
      dictionaryValueToValue (SubclassDictionaryValue dict superclassName index) =
        App (Accessor (C.__superclass_ ++ show superclassName ++ "_" ++ show index)
                      (dictionaryValueToValue dict))
            valUndefined
      -- Ensure that a substitution is valid: all types bound to the same type variable must
      -- pairwise unify. The bindings are sorted by variable name first because 'groupBy' only
      -- groups adjacent elements, and a variable can occur at non-adjacent positions of an
      -- instance head (e.g. C a b a). 'sortBy' is stable, so the first binding of each variable
      -- is the one kept. 'head' is safe: 'groupBy' never produces empty groups.
      verifySubstitution :: [(String, Type)] -> Maybe [(String, Type)]
      verifySubstitution subst = do
        let grps = groupBy ((==) `on` fst) (sortBy (compare `on` fst) subst)
        guard (all (pairwise (unifiesWith env) . map snd) grps)
        return $ map head grps
      -- |
      -- Check for overlapping instances
      --
      checkOverlaps :: [DictionaryValue] -> Check Expr
      checkOverlaps dicts =
        case [ (d1, d2) | d1 <- dicts, d2 <- dicts, d1 `overlapping` d2 ] of
          ds@(_ : _) -> throwError . errorMessage $ OverlappingInstances className tys $ nub (map fst ds)
          _ -> case chooseSimplestDictionaries dicts of
                 [] -> throwError . errorMessage $ NoInstanceFound className tys
                 d : _ -> return $ dictionaryValueToValue d
      -- Choose the simplest DictionaryValues from a list of candidates
      -- The reason for this function is as follows:
      -- When considering overlapping instances, we don't want to consider the same dictionary
      -- to be an overlap of itself when obtained as a superclass of another class.
      -- Observing that we probably don't want to select a superclass instance when an instance
      -- is available directly, and that there is no way for a superclass instance to actually
      -- introduce an overlap that wouldn't have been there already, we simply remove dictionaries
      -- obtained as superclass instances if there are simpler instances available.
      chooseSimplestDictionaries :: [DictionaryValue] -> [DictionaryValue]
      chooseSimplestDictionaries ds = case filter isSimpleDictionaryValue ds of
                                        [] -> ds
                                        simple -> simple
      -- A dictionary is simple if no superclass projection occurs anywhere inside it
      isSimpleDictionaryValue :: DictionaryValue -> Bool
      isSimpleDictionaryValue SubclassDictionaryValue{} = False
      isSimpleDictionaryValue (DependentDictionaryValue _ ds) = all isSimpleDictionaryValue ds
      isSimpleDictionaryValue _ = True
      -- |
      -- Check if two dictionaries are overlapping
      --
      -- Dictionaries which are subclass dictionaries cannot overlap, since otherwise the overlap would have
      -- been caught when constructing superclass dictionaries.
      --
      overlapping :: DictionaryValue -> DictionaryValue -> Bool
      overlapping (LocalDictionaryValue nm1)         (LocalDictionaryValue nm2)  | nm1 == nm2 = False
      overlapping (GlobalDictionaryValue nm1)        (GlobalDictionaryValue nm2) | nm1 == nm2 = False
      overlapping (DependentDictionaryValue nm1 ds1) (DependentDictionaryValue nm2 ds2)
        | nm1 == nm2 = or $ zipWith overlapping ds1 ds2
      overlapping SubclassDictionaryValue{} _ = False
      overlapping _ SubclassDictionaryValue{} = False
      overlapping _ _ = True

    -- Placeholder argument passed when invoking a superclass accessor
    valUndefined :: Expr
    valUndefined = Var (Qualified (Just (ModuleName [ProperName C.prim])) (Ident C.undefined))

-- |
-- Check whether the type heads of two types are equal (for the purposes of type class dictionary lookup),
-- and return a substitution from type variables to types which makes the type heads unify.
--
-- A type variable on the right-hand side matches any type on the left-hand side and is bound to it.
-- Saturated type synonyms on the left-hand side are expanded before comparison. Returns 'Nothing'
-- when the heads do not match.
--
typeHeadsAreEqual :: ModuleName -> Environment -> Type -> Type -> Maybe [(String, Type)]
typeHeadsAreEqual _ _ (Skolem _ s1 _)      (Skolem _ s2 _)      | s1 == s2 = Just []
typeHeadsAreEqual _ _ t                    (TypeVar v)                     = Just [(v, t)]
typeHeadsAreEqual _ _ (TypeConstructor c1) (TypeConstructor c2) | c1 == c2 = Just []
typeHeadsAreEqual m e (TypeApp h1 t1)      (TypeApp h2 t2)                 = (++) <$> typeHeadsAreEqual m e h1 h2
                                                                                  <*> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual m e (SaturatedTypeSynonym name args) t2 = case expandTypeSynonym' e name args of
  Left  _  -> Nothing
  Right t1 -> typeHeadsAreEqual m e t1 t2
typeHeadsAreEqual _ _ REmpty REmpty = Just []
typeHeadsAreEqual m e r1@(RCons _ _ _) r2@(RCons _ _ _) =
  let (s1, r1') = rowToList r1
      (s2, r2') = rowToList r2

      -- Pairs of field types for labels present in both rows
      int = [ (t1, t2) | (name, t1) <- s1, (name', t2) <- s2, name == name' ]
      -- Fields whose labels appear only in the left / only in the right row
      sd1 = [ (name, t1) | (name, t1) <- s1, name `notElem` map fst s2 ]
      sd2 = [ (name, t2) | (name, t2) <- s2, name `notElem` map fst s1 ]
  -- Every shared label must match, so 'mapM' in 'Maybe' is used: it fails if any pair fails.
  -- ('foldMap' would be wrong here: in the 'Monoid (Maybe a)' instance 'Nothing' is the identity,
  -- so a mismatched label would be silently ignored whenever another label matched.)
  in (++) <$> (concat <$> mapM (uncurry (typeHeadsAreEqual m e)) int)
          <*> go sd1 r1' sd2 r2'
  where
  -- Match the leftover fields and row tails of both rows
  go :: [(String, Type)] -> Type -> [(String, Type)] -> Type -> Maybe [(String, Type)]
  go [] REmpty          [] REmpty          = Just []
  go [] (TUnknown _)    _  _               = Just []
  go [] (TypeVar v1)    [] (TypeVar v2)    | v1 == v2 = Just []
  go [] (Skolem _ s1 _) [] (Skolem _ s2 _) | s1 == s2 = Just []
  -- A row variable tail on the right absorbs the left row's extra fields and tail
  go sd r               [] (TypeVar v)     = Just [(v, rowFromList (sd, r))]
  go _  _               _  _               = Nothing
typeHeadsAreEqual _ _ _ _ = Nothing

-- |
-- Check that a predicate holds for every pair of elements taken in list order: each element is
-- passed as the first argument together with every element that follows it. Empty and
-- singleton lists trivially satisfy any predicate. Evaluation stops at the first failing pair.
--
pairwise :: (a -> a -> Bool) -> [a] -> Bool
pairwise p = and . concatMap againstLater . tails
  where
  againstLater [] = []
  againstLater (x : later) = map (p x) later