Skip to content

Commit

Permalink
Preallocate the chunk size rather than buffering
Browse files Browse the repository at this point in the history
Since the chunk size is capped at 4MB now, we can safely preallocate
it so that we don't have to buffer each chunk.
  • Loading branch information
twiss committed Dec 16, 2024
1 parent add07bd commit 6fa7f91
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 41 deletions.
54 changes: 17 additions & 37 deletions openpgp/packet/aead_crypter.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func (wo *aeadCrypter) incrementIndex() error {
type aeadDecrypter struct {
aeadCrypter // Embedded ciphertext opener
reader io.Reader // 'reader' is a partialLengthReader
chunkBytes []byte
peekedBytes []byte // Used to detect last chunk
eof bool
}

// Read decrypts bytes and reads them into dst. It decrypts when necessary and
Expand All @@ -75,22 +75,18 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
return ar.buffer.Read(dst)
}

// Return EOF if we've previously validated the final tag
if ar.eof {
return 0, io.EOF
}

// Read a chunk
tagLen := ar.aead.Overhead()
cipherChunkBuf := new(bytes.Buffer)
_, errRead := io.CopyN(cipherChunkBuf, ar.reader, int64(ar.chunkSize+tagLen))
cipherChunk := cipherChunkBuf.Bytes()
if errRead != nil && errRead != io.EOF {
copy(ar.chunkBytes, ar.peekedBytes) // Copy bytes peeked in previous chunk or in initialization
bytesRead, errRead := io.ReadFull(ar.reader, ar.chunkBytes[tagLen:])
if errRead != nil && errRead != io.EOF && errRead != io.ErrUnexpectedEOF {
return 0, errRead
}

if len(cipherChunk) > 0 {
decrypted, errChunk := ar.openChunk(cipherChunk)
if bytesRead > 0 {
ar.peekedBytes = ar.chunkBytes[bytesRead:bytesRead+tagLen]

decrypted, errChunk := ar.openChunk(ar.chunkBytes[:bytesRead])
if errChunk != nil {
return 0, errChunk
}
Expand All @@ -102,28 +98,19 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
} else {
n = copy(dst, decrypted)
}
return
}

// Check final authentication tag
if errRead == io.EOF {
errChunk := ar.validateFinalTag(ar.peekedBytes)
if errChunk != nil {
return n, errChunk
}
ar.eof = true // Mark EOF for when we've returned all buffered data
}
return
return 0, io.EOF
}

// Close is noOp. The final authentication tag of the stream was already
// checked in the last Read call. In the future, this function could be used to
// wipe the reader and peeked, decrypted bytes, if necessary.
// Close checks the final authentication tag of the stream.
// In the future, this function could also be used to wipe the reader
// and peeked & decrypted bytes, if necessary.
func (ar *aeadDecrypter) Close() (err error) {
if !ar.eof {
errChunk := ar.validateFinalTag(ar.peekedBytes)
if errChunk != nil {
return errChunk
}
errChunk := ar.validateFinalTag(ar.peekedBytes)
if errChunk != nil {
return errChunk
}
return nil
}
Expand All @@ -132,20 +119,13 @@ func (ar *aeadDecrypter) Close() (err error) {
// the underlying plaintext and an error. It accesses peeked bytes from next
// chunk, to identify the last chunk and decrypt/validate accordingly.
func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) {
tagLen := ar.aead.Overhead()
// Restore carried bytes from last call
chunkExtra := append(ar.peekedBytes, data...)
// 'chunk' contains encrypted bytes, followed by an authentication tag.
chunk := chunkExtra[:len(chunkExtra)-tagLen]
ar.peekedBytes = chunkExtra[len(chunkExtra)-tagLen:]

adata := ar.associatedData
if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted {
adata = append(ar.associatedData, ar.chunkIndex...)
}

nonce := ar.computeNextNonce()
plainChunk, err := ar.aead.Open(nil, nonce, chunk, adata)
plainChunk, err := ar.aead.Open(nil, nonce, data, adata)
if err != nil {
return nil, errors.ErrAEADTagVerification
}
Expand Down
10 changes: 7 additions & 3 deletions openpgp/packet/aead_encrypted.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ func (ae *AEADEncrypted) decrypt(key []byte) (io.ReadCloser, error) {
blockCipher := ae.cipher.new(key)
aead := ae.mode.new(blockCipher)
// Carry the first tagLen bytes
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)
tagLen := ae.mode.TagLength()
peekedBytes := make([]byte, tagLen)
chunkBytes := make([]byte, chunkSize+tagLen*2)
peekedBytes := chunkBytes[chunkSize+tagLen:]
n, err := io.ReadFull(ae.Contents, peekedBytes)
if n < tagLen || (err != nil && err != io.EOF) {
return nil, errors.AEADError("Not enough data to decrypt:" + err.Error())
}
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)

return &aeadDecrypter{
aeadCrypter: aeadCrypter{
aead: aead,
Expand All @@ -82,7 +84,9 @@ func (ae *AEADEncrypted) decrypt(key []byte) (io.ReadCloser, error) {
packetTag: packetTypeAEADEncrypted,
},
reader: ae.Contents,
peekedBytes: peekedBytes}, nil
chunkBytes: chunkBytes,
peekedBytes: peekedBytes,
}, nil
}

// associatedData for chunks: tag, version, cipher, mode, chunk size byte
Expand Down
1 change: 1 addition & 0 deletions openpgp/packet/aead_encrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ func readDecryptedStream(rc io.ReadCloser) (got []byte, err error) {
}
}
}
err = rc.Close()
return got, err
}

Expand Down
5 changes: 4 additions & 1 deletion openpgp/packet/symmetrically_encrypted_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, e

aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData())
// Carry the first tagLen bytes
chunkSize := decodeAEADChunkSize(se.ChunkSizeByte)
tagLen := se.Mode.TagLength()
peekedBytes := make([]byte, tagLen)
chunkBytes := make([]byte, chunkSize+tagLen*2)
peekedBytes := chunkBytes[chunkSize+tagLen:]
n, err := io.ReadFull(se.Contents, peekedBytes)
if n < tagLen || (err != nil && err != io.EOF) {
return nil, errors.StructuralError("not enough data to decrypt:" + err.Error())
Expand All @@ -87,6 +89,7 @@ func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, e
packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected,
},
reader: se.Contents,
chunkBytes: chunkBytes,
peekedBytes: peekedBytes,
}, nil
}
Expand Down

0 comments on commit 6fa7f91

Please sign in to comment.