{-# LANGUAGE CPP #-}
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import           Control.Monad.Catch                  (MonadThrow, throwM)
import           Control.Monad.Trans.Class            (lift)
import qualified Crypto.Cipher.ChaCha                 as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC                           as ECC
import qualified Crypto.Error                         as CE
import           Crypto.Hash                          (SHA512 (..), hashWith)
import           Crypto.PubKey.ECIES                  (deriveDecrypt,
                                                       deriveEncrypt)
import           Crypto.Random                        (MonadRandom)
import qualified Data.ByteArray                       as BA
import           Data.ByteString                      (ByteString)
import qualified Data.ByteString                      as B
import qualified Data.ByteString.Lazy                 as BL
import           Data.Conduit                         (ConduitM, yield)
import qualified Data.Conduit.Binary                  as CB
import           Data.Proxy                           (Proxy (..))
import           System.IO.Unsafe                     (unsafePerformIO)

-- | Derive the @(nonce, key)@ pair for the stream cipher from an ECDH
-- shared secret.
--
-- The shared secret is hashed with SHA-512. The first 40 bytes of the
-- digest seed ChaCha's simple generator ('ChaCha.initializeSimple'). A
-- 12-byte nonce is drawn from that generator first, then a 32-byte key.
-- The derivation is pure and deterministic, so 'encrypt' and 'decrypt'
-- obtain the same key from the same shared secret.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared = (nonce, key)
  where
    -- Seed material: a 40-byte prefix of the SHA-512 digest of the secret.
    seed :: ByteString
    seed = B.take 40 (BA.convert (hashWith SHA512 shared))
    gen0 = ChaCha.initializeSimple seed
    -- Draw order matters: nonce first, then key from the advanced state.
    (nonce, gen1) = ChaCha.generateSimple gen0 12
    (key, _) = ChaCha.generateSimple gen1 32

-- | The elliptic curve this module is fixed to (P-256); 'encrypt',
-- 'decrypt' and the point-size computation all use it.
type Curve = ECC.Curve_P256R1

-- | Type-level witness for 'Curve'. It is passed to the cryptonite
-- functions that are polymorphic in the curve (point encoding and decoding,
-- key generation, ECIES derivation).
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of an encoded 'Curve' point. This is the size of the
-- ephemeral-point header that 'encrypt' writes and 'decrypt' reads back.
--
-- The size is measured by encoding a freshly generated public key rather
-- than hard-coding a constant. Generating that key needs IO randomness,
-- hence 'unsafePerformIO'. Only the encoded length of the point is used,
-- never its value, so which point is generated is irrelevant.
-- NOTE(review): this assumes 'ECC.encodePoint' yields the same length for
-- every point of 'Curve'. 'decrypt' relies on the same assumption when it
-- takes exactly this many bytes.
--
-- The NOINLINE pragma keeps this top-level value from being inlined at use
-- sites, so the key generation is not duplicated.
pointBinarySize :: Int
pointBinarySize = B.length $ ECC.encodePoint proxy samplePoint
  where
    samplePoint :: ECC.Point Curve
    samplePoint =
      unsafePerformIO (ECC.keypairGetPublic <$> ECC.curveGenerateKeyPair proxy)
{-# NOINLINE pointBinarySize #-}

-- | Run a cryptonite 'CE.CryptoFailable' result in any 'MonadThrow'.
--
-- A 'CE.CryptoPassed' value is returned as-is. A 'CE.CryptoFailed' error
-- is raised with 'throwM', so callers see it as an exception of type
-- 'CE.CryptoError'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail (CE.CryptoPassed a) = pure a
throwOnFail (CE.CryptoFailed e) = throwM e


-- | Encrypt a byte stream for the owner of the secret scalar matching
-- @point@.
--
-- Steps:
--
-- * 'deriveEncrypt' produces an ephemeral point and a shared secret.
-- * 'getNonceKey' turns the shared secret into a nonce and a key.
-- * The encoded ephemeral point is emitted first.
-- * The input is then streamed through 'ChaCha.encrypt' with that nonce
--   and key.
--
-- If the derivation fails, the 'CE.CryptoError' is thrown via 'throwM'.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt point = do
  (point', shared) <- lift (deriveEncryptCompat proxy point) >>= throwOnFail
  let (nonce, key) = getNonceKey shared
  -- Header: the ephemeral public point; 'decrypt' reads back exactly
  -- 'pointBinarySize' bytes of it.
  yield $ ECC.encodePoint proxy point'
  ChaCha.encrypt nonce key
  where
    -- Version shim: from cryptonite 0.24 on, 'deriveEncrypt' already returns
    -- a 'CE.CryptoFailable'. Older versions returned the pair directly, so
    -- that result is wrapped in 'CE.CryptoPassed' to give one shape.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveEncryptCompat prx p = deriveEncrypt prx p
#else
    deriveEncryptCompat prx p = CE.CryptoPassed <$> deriveEncrypt prx p
#endif

-- | Inverse of 'encrypt': decrypt a byte stream with the recipient's secret
-- @scalar@.
--
-- Steps:
--
-- * The first 'pointBinarySize' bytes are decoded as the ephemeral point.
-- * The shared secret is derived from that point and @scalar@.
-- * The rest of the stream goes to 'ChaCha.decrypt' with the key from
--   'getNonceKey'.
--
-- A 'CE.CryptoError' is thrown via 'throwM' if the point cannot be decoded
-- or the derivation fails.
decrypt
  :: (MonadThrow m)
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  pointBS <- BL.toStrict <$> CB.take pointBinarySize
  point   <- throwOnFail (ECC.decodePoint proxy pointBS)
  shared  <- throwOnFail (deriveDecryptCompat proxy point scalar)
  -- Only the key is used here; the derived nonce is discarded.
  -- NOTE(review): presumably ChaCha.decrypt recovers the nonce from the
  -- stream written by ChaCha.encrypt -- confirm against that module.
  let (_nonce, key) = getNonceKey shared
  ChaCha.decrypt key
  where
    -- Version shim: from cryptonite 0.24 on, 'deriveDecrypt' already returns
    -- a 'CE.CryptoFailable'. Older versions returned the secret directly, so
    -- that result is wrapped in 'CE.CryptoPassed' to give one shape.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveDecryptCompat prx p s = deriveDecrypt prx p s
#else
    deriveDecryptCompat prx p s = CE.CryptoPassed (deriveDecrypt prx p s)
#endif