{-# LANGUAGE DeriveDataTypeable #-}
module Crypto.Cipher.ChaChaPoly1305.Conduit
( encrypt
, decrypt
, ChaChaException (..)
) where
import Control.Exception (assert)
import Control.Monad.Catch (Exception, MonadThrow, throwM)
import qualified Crypto.Cipher.ChaChaPoly1305 as Cha
import qualified Crypto.Error as CE
import qualified Crypto.MAC.Poly1305 as Poly1305
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, await, leftover, yield)
import qualified Data.Conduit.Binary as CB
import Data.Typeable (Typeable)
-- | Lift a 'CE.CryptoFailable' result into any 'MonadThrow' monad.
--
-- A 'CE.CryptoPassed' value is returned unchanged. A 'CE.CryptoFailed'
-- error is wrapped with the supplied constructor (which records which
-- step failed) and raised with 'throwM'.
cf :: MonadThrow m
   => (CE.CryptoError -> ChaChaException) -- ^ tags the error with the failing step
   -> CE.CryptoFailable a                  -- ^ result of a cryptonite operation
   -> m a
cf _    (CE.CryptoPassed x)   = return x
cf wrap (CE.CryptoFailed err) = throwM (wrap err)
-- | Errors raised by 'encrypt' and 'decrypt'.
--
-- The constructors carrying a 'CE.CryptoError' wrap the error reported by
-- the underlying cryptonite call that rejected its input.
data ChaChaException
  = EncryptNonceException !CE.CryptoError
    -- ^ 'encrypt': 'Cha.nonce12' rejected the supplied nonce.
  | EncryptKeyException !CE.CryptoError
    -- ^ 'encrypt': 'Cha.initialize' rejected the supplied key.
  | DecryptNonceException !CE.CryptoError
    -- ^ 'decrypt': the leading (up to 12) bytes of the input were not
    -- accepted by 'Cha.nonce12'.
  | DecryptKeyException !CE.CryptoError
    -- ^ 'decrypt': 'Cha.initialize' rejected the supplied key.
  | MismatchedAuth
    -- ^ 'decrypt': at end of input, the trailing authentication tag was
    -- malformed or did not match the computed tag.
  deriving (Show, Typeable)

instance Exception ChaChaException
-- | Encrypt a byte stream with ChaCha20-Poly1305.
--
-- Output framing, in order:
--
--   1. the nonce bytes, exactly as supplied;
--   2. one ciphertext chunk per input chunk;
--   3. the Poly1305 authentication tag, emitted once the input is exhausted.
--
-- No associated data is authenticated: the AAD phase is closed immediately
-- with 'Cha.finalizeAAD'. This framing is what 'decrypt' expects.
--
-- Throws 'EncryptNonceException' if the nonce is rejected by 'Cha.nonce12'.
-- Throws 'EncryptKeyException' if the key is rejected by 'Cha.initialize'.
-- Both checks happen before anything is yielded.
encrypt
  :: MonadThrow m
  => ByteString -- ^ nonce; must be acceptable to 'Cha.nonce12' (12 bytes)
  -> ByteString -- ^ symmetric key, validated by 'Cha.initialize'
  -> ConduitM ByteString ByteString m ()
encrypt nonceBS key = do
  nonce  <- cf EncryptNonceException $ Cha.nonce12 nonceBS
  state0 <- cf EncryptKeyException $ Cha.initialize key nonce
  -- The nonce is sent in the clear so the receiver can rebuild the state.
  yield nonceBS
  loop (Cha.finalizeAAD state0)
  where
    -- Encrypt each incoming chunk as it arrives. On end of input, append
    -- the authentication tag computed over everything encrypted so far.
    loop st = do
      mbs <- await
      case mbs of
        Nothing -> yield $ BA.convert $ Cha.finalize st
        Just bs -> do
          let (ciphertext, st') = Cha.encrypt bs st
          yield ciphertext
          loop st'
-- | Decrypt a stream produced by 'encrypt' using ChaCha20-Poly1305.
--
-- Expected input layout:
--
--   * the first 12 bytes are the nonce;
--   * the last 16 bytes are the Poly1305 authentication tag;
--   * everything in between is ciphertext.
--
-- No associated data is used: the AAD phase is closed immediately with
-- 'Cha.finalizeAAD'.
--
-- Errors:
--
--   * 'DecryptNonceException' if the leading bytes are rejected by
--     'Cha.nonce12', e.g. because the input is shorter than 12 bytes;
--   * 'DecryptKeyException' if the key is rejected by 'Cha.initialize';
--   * 'MismatchedAuth' at end of input if the trailing bytes do not form a
--     valid tag, or the tag does not match.
--
-- WARNING: plaintext is yielded downstream as soon as each chunk is
-- decrypted. This happens /before/ the tag is checked, which only occurs at
-- end of input. Downstream consumers must treat the output as
-- unauthenticated until this conduit finishes without throwing.
decrypt
  :: MonadThrow m
  => ByteString -- ^ symmetric key, validated by 'Cha.initialize'
  -> ConduitM ByteString ByteString m ()
decrypt key = do
  nonceBS <- CB.take nonceLength
  nonce   <- cf DecryptNonceException $ Cha.nonce12 $ BL.toStrict nonceBS
  state0  <- cf DecryptKeyException $ Cha.initialize key nonce
  loop (Cha.finalizeAAD state0)
  where
    nonceLength, tagLength :: Int
    nonceLength = 12
    tagLength   = 16

    -- Decrypt everything except the trailing tag-sized window. At end of
    -- input, verify that window against the computed tag. 'BA.constEq'
    -- keeps the comparison constant-time.
    loop st = do
      next <- awaitExceptTag B.empty
      case next of
        Left trailer ->
          case Poly1305.authTag trailer of
            CE.CryptoPassed tag | BA.constEq (Cha.finalize st) tag -> return ()
            _ -> throwM MismatchedAuth
        Right ciphertext -> do
          let (plaintext, st') = Cha.decrypt ciphertext st
          yield plaintext
          loop st'

    -- Accumulate input (starting from @acc@) until more than 'tagLength'
    -- bytes are buffered. Then return all but the last 'tagLength' bytes as
    -- @Right@, and push those last bytes back upstream with 'leftover'.
    -- If input ends first, return the whole buffer as @Left@; this should
    -- be the tag.
    awaitExceptTag
      :: Monad n
      => ByteString
      -> ConduitM ByteString o n (Either ByteString ByteString)
    awaitExceptTag acc = do
      mbs <- await
      case mbs of
        Nothing -> return (Left acc)
        Just bs -> do
          let buf = B.append acc bs
          if B.length buf > tagLength
            then do
              let (body, held) = B.splitAt (B.length buf - tagLength) buf
              assert (B.length held == tagLength) leftover held
              return (Right body)
            else awaitExceptTag buf