{-# LANGUAGE CPP #-}

-- | Streaming ECIES encryption and decryption over the 'Curve' (P-256)
-- elliptic curve.
--
-- The stream produced by 'encrypt' (and consumed by 'decrypt') is:
--
-- 1. a point encoded with 'ECC.encodePoint' (exactly 'pointBinarySize'
--    bytes), followed by
-- 2. the output of the ChaCha20-Poly1305 conduit from
--    "Crypto.Cipher.ChaChaPoly1305.Conduit", keyed with material derived
--    from the ECDH shared secret by 'getNonceKey'.
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import Control.Monad.Catch (MonadThrow, throwM)
import Control.Monad.Trans.Class (lift)
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC as ECC
import qualified Crypto.Error as CE
import Crypto.Hash (SHA512 (..), hashWith)
import Crypto.PubKey.ECIES (deriveDecrypt, deriveEncrypt)
import Crypto.Random (MonadRandom)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, yield)
import qualified Data.Conduit.Binary as CB
import Data.Proxy (Proxy (..))
import System.IO.Unsafe (unsafePerformIO)

-- | Derive a @(nonce, key)@ pair from an ECDH shared secret.
--
-- The shared secret is hashed with SHA-512, and the first 40 bytes of the
-- digest seed a simple ChaCha generator. The first 12 generated bytes
-- become the nonce and the following 32 bytes become the key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared =
  let state1 = ChaCha.initializeSimple $ B.take 40 $ BA.convert $ hashWith SHA512 shared
      (nonce, state2) = ChaCha.generateSimple state1 12
      (key, _) = ChaCha.generateSimple state2 32
   in (nonce, key)

-- | The curve used for every operation in this module.
type Curve = ECC.Curve_P256R1

-- | Type-level handle for 'Curve', passed to the "Crypto.ECC" functions.
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of an encoded 'Curve' point. This is the size of the
-- header that 'decrypt' reads before the ChaCha20-Poly1305 stream starts.
--
-- It is measured by encoding a freshly generated public key.
-- 'unsafePerformIO' is used only to obtain that throwaway key pair. This is
-- sound only if the encoded length is the same for every point of the curve
-- (true for a fixed-width point encoding; re-check if 'Curve' changes), so
-- the value is a constant regardless of which key was drawn. @NOINLINE@
-- keeps this a single shared top-level value, so the key generation is not
-- duplicated at use sites.
pointBinarySize :: Int
pointBinarySize = B.length $ ECC.encodePoint proxy point
  where
    point = unsafePerformIO (ECC.keypairGetPublic <$> ECC.curveGenerateKeyPair proxy)
{-# NOINLINE pointBinarySize #-}

-- | Turn a 'CE.CryptoFailable' into a value in any 'MonadThrow'. A
-- 'CE.CryptoFailed' error is rethrown with 'throwM'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail (CE.CryptoPassed a) = pure a
throwOnFail (CE.CryptoFailed e) = throwM e

-- | Encrypt a byte stream for the holder of the private key matching the
-- given public point.
--
-- Steps:
--
-- 1. Derive a fresh point and shared secret with 'deriveEncrypt', using the
--    randomness of @m@.
-- 2. Yield the encoded point.
-- 3. Stream the input through ChaCha20-Poly1305, using the nonce and key
--    from 'getNonceKey'.
--
-- A failed key derivation is rethrown via 'MonadThrow'.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt point = do
  (point', shared) <- lift (deriveEncryptCompat proxy point) >>= throwOnFail
  let (nonce, key) = getNonceKey shared
  yield $ ECC.encodePoint proxy point'
  ChaCha.encrypt nonce key
  where
    -- cryptonite >= 0.24 returns a 'CE.CryptoFailable' from 'deriveEncrypt';
    -- older versions return the pair directly, so wrap it for uniformity.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveEncryptCompat prx p = deriveEncrypt prx p
#else
    deriveEncryptCompat prx p = CE.CryptoPassed <$> deriveEncrypt prx p
#endif

-- | Decrypt a stream produced by 'encrypt', using the recipient's private
-- scalar.
--
-- Steps:
--
-- 1. Read 'pointBinarySize' bytes and decode them as a point.
-- 2. Derive the shared secret with 'deriveDecrypt'.
-- 3. Stream the rest of the input through the ChaCha20-Poly1305 decryption
--    conduit, using the key from 'getNonceKey'.
--
-- Point-decoding and key-derivation failures are rethrown via
-- 'MonadThrow'. A truncated header (fewer bytes than 'pointBinarySize') is
-- presumably rejected by 'ECC.decodePoint'; confirm.
decrypt
  :: (MonadThrow m)
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  pointBS <- BL.toStrict <$> CB.take pointBinarySize
  point <- throwOnFail (ECC.decodePoint proxy pointBS)
  shared <- throwOnFail (deriveDecryptCompat proxy point scalar)
  -- The derived nonce is not passed on. Presumably the ChaCha20-Poly1305
  -- conduit recovers it from the stream; confirm against
  -- Crypto.Cipher.ChaChaPoly1305.Conduit.
  let (_nonce, key) = getNonceKey shared
  ChaCha.decrypt key
  where
    -- cryptonite >= 0.24 returns a 'CE.CryptoFailable' from 'deriveDecrypt';
    -- older versions return the secret directly, so wrap it for uniformity.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveDecryptCompat prx p s = deriveDecrypt prx p s
#else
    deriveDecryptCompat prx p s = CE.CryptoPassed (deriveDecrypt prx p s)
#endif