Skip to content
This repository has been archived by the owner on Apr 12, 2022. It is now read-only.

Commit

Permalink
Merge pull request #422 from matrix-org/feature/create_rec_progress
Browse files Browse the repository at this point in the history
Extract ProgressListener.
  • Loading branch information
bmarty authored Jan 31, 2019
2 parents 9d38613 + 3b26584 commit 7fd0f2e
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package org.matrix.androidsdk.crypto.keysbackup

import android.support.test.runner.AndroidJUnit4
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Assert.*
import org.junit.Before
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import org.matrix.androidsdk.common.assertByteArrayNotEqual
import org.matrix.androidsdk.listeners.ProgressListener
import org.matrix.olm.OlmManager
import org.matrix.olm.OlmPkDecryption

Expand All @@ -42,7 +42,7 @@ class KeysBackupPasswordTest {
*/
@Test
fun passwordConverter_ok() {
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD)
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD, null)

assertEquals(32, generatePrivateKeyResult.salt.length)
assertEquals(500_000, generatePrivateKeyResult.iterations)
Expand All @@ -57,12 +57,40 @@ class KeysBackupPasswordTest {
assertArrayEquals(generatePrivateKeyResult.privateKey, retrievedPrivateKey)
}

/**
* Check generatePrivateKeyWithPassword progress listener behavior
*/
@Test
fun passwordConverter_progress_ok() {
val progressValues = ArrayList<Int>(101)
var lastTotal = 0

generatePrivateKeyWithPassword(PASSWORD, object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
if (!progressValues.contains(progress)) {
progressValues.add(progress)
}

lastTotal = total
}
})

assertEquals(100, lastTotal)

// Ensure all values are here
assertEquals(101, progressValues.size)

for (i in 0..100) {
assertTrue(progressValues[i] == i)
}
}

/**
* Check KeysBackupPassword utilities, with bad password
*/
@Test
fun passwordConverter_badPassword_ok() {
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD)
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD, null)

assertEquals(32, generatePrivateKeyResult.salt.length)
assertEquals(500_000, generatePrivateKeyResult.iterations)
Expand All @@ -82,7 +110,7 @@ class KeysBackupPasswordTest {
*/
@Test
fun passwordConverter_badIteration_ok() {
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD)
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD, null)

assertEquals(32, generatePrivateKeyResult.salt.length)
assertEquals(500_000, generatePrivateKeyResult.iterations)
Expand All @@ -102,7 +130,7 @@ class KeysBackupPasswordTest {
*/
@Test
fun passwordConverter_badSalt_ok() {
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD)
val generatePrivateKeyResult = generatePrivateKeyWithPassword(PASSWORD, null)

assertEquals(32, generatePrivateKeyResult.salt.length)
assertEquals(500_000, generatePrivateKeyResult.iterations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.matrix.androidsdk.crypto.MegolmSessionData
import org.matrix.androidsdk.crypto.data.ImportRoomKeysResult
import org.matrix.androidsdk.crypto.data.MXDeviceInfo
import org.matrix.androidsdk.crypto.data.MXOlmInboundGroupSession2
import org.matrix.androidsdk.listeners.ProgressListener
import org.matrix.androidsdk.rest.callback.SuccessCallback
import org.matrix.androidsdk.rest.callback.SuccessErrorCallback
import org.matrix.androidsdk.rest.model.keys.CreateKeysBackupVersionBody
Expand Down Expand Up @@ -113,7 +114,7 @@ class KeysBackupTest {

val latch = CountDownLatch(1)

keysBackup.prepareKeysBackupVersion(null, object : SuccessErrorCallback<MegolmBackupCreationInfo> {
keysBackup.prepareKeysBackupVersion(null, null, object : SuccessErrorCallback<MegolmBackupCreationInfo> {
override fun onSuccess(info: MegolmBackupCreationInfo?) {
assertNotNull(info)

Expand Down Expand Up @@ -151,7 +152,7 @@ class KeysBackupTest {

var megolmBackupCreationInfo: MegolmBackupCreationInfo? = null
val latch = CountDownLatch(1)
keysBackup.prepareKeysBackupVersion(null, object : SuccessErrorCallback<MegolmBackupCreationInfo> {
keysBackup.prepareKeysBackupVersion(null, null, object : SuccessErrorCallback<MegolmBackupCreationInfo> {

override fun onSuccess(info: MegolmBackupCreationInfo) {
megolmBackupCreationInfo = info
Expand Down Expand Up @@ -269,10 +270,10 @@ class KeysBackupTest {

var lastBackedUpKeysProgress = 0

keysBackup.backupAllGroupSessions(object : KeysBackup.BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
keysBackup.backupAllGroupSessions(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
assertEquals(nbOfKeys, total)
lastBackedUpKeysProgress = backedUp
lastBackedUpKeysProgress = progress
}

}, TestApiCallback(latch))
Expand Down Expand Up @@ -697,8 +698,8 @@ class KeysBackupTest {

// - Make alice back up all her keys again
val latch2 = CountDownLatch(1)
keysBackup.backupAllGroupSessions(object : KeysBackup.BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
keysBackup.backupAllGroupSessions(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
}

}, TestApiCallback(latch2, false))
Expand Down Expand Up @@ -736,8 +737,8 @@ class KeysBackupTest {

// Wait for keys backup to finish by asking again to backup keys.
val latch = CountDownLatch(1)
keysBackup.backupAllGroupSessions(object : KeysBackup.BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
keysBackup.backupAllGroupSessions(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {

}
}, TestApiCallback(latch))
Expand Down Expand Up @@ -765,8 +766,8 @@ class KeysBackupTest {

var isSuccessful = false
val latch2 = CountDownLatch(1)
keysBackup2.backupAllGroupSessions(object : KeysBackup.BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
keysBackup2.backupAllGroupSessions(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
}

}, object : TestApiCallback<Void?>(latch2, false) {
Expand Down Expand Up @@ -828,7 +829,7 @@ class KeysBackupTest {
password: String? = null): PrepareKeysBackupDataResult {
var megolmBackupCreationInfo: MegolmBackupCreationInfo? = null
val latch = CountDownLatch(1)
keysBackup.prepareKeysBackupVersion(password, object : SuccessErrorCallback<MegolmBackupCreationInfo> {
keysBackup.prepareKeysBackupVersion(password, null, object : SuccessErrorCallback<MegolmBackupCreationInfo> {

override fun onSuccess(info: MegolmBackupCreationInfo) {
megolmBackupCreationInfo = info
Expand Down Expand Up @@ -913,17 +914,17 @@ class KeysBackupTest {
val prepareKeysBackupDataResult = prepareAndCreateKeysBackupData(keysBackup, password)

val latch = CountDownLatch(1)
var lastBackup = 0
var lastProgress = 0
var lastTotal = 0
keysBackup.backupAllGroupSessions(object : KeysBackup.BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
lastBackup = backedUp
keysBackup.backupAllGroupSessions(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
lastProgress = progress
lastTotal = total
}
}, TestApiCallback(latch))
mTestHelper.await(latch)

assertEquals(2, lastBackup)
assertEquals(2, lastProgress)
assertEquals(2, lastTotal)

// - Log Alice on a new device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.matrix.androidsdk.crypto.data.MXDeviceInfo
import org.matrix.androidsdk.crypto.data.MXOlmInboundGroupSession2
import org.matrix.androidsdk.crypto.util.computeRecoveryKey
import org.matrix.androidsdk.crypto.util.extractCurveKeyFromRecoveryKey
import org.matrix.androidsdk.listeners.ProgressListener
import org.matrix.androidsdk.rest.callback.ApiCallback
import org.matrix.androidsdk.rest.callback.SimpleApiCallback
import org.matrix.androidsdk.rest.callback.SuccessCallback
Expand Down Expand Up @@ -93,9 +94,11 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {
*
* @param password an optional passphrase string that can be entered by the user
* when restoring the backup as an alternative to entering the recovery key.
* @param progressListener a progress listener, as generating private key from password may take a while
* @param callback Asynchronous callback
*/
fun prepareKeysBackupVersion(password: String?,
progressListener: ProgressListener?,
callback: SuccessErrorCallback<MegolmBackupCreationInfo>) {
mCrypto.decryptingThreadHandler.post {
try {
Expand All @@ -104,7 +107,23 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {

if (password != null) {
// Generate a private key from the password
val generatePrivateKeyResult = generatePrivateKeyWithPassword(password)
val backgroundProgressListener = if (progressListener == null) {
null
} else {
object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
mCrypto.uiHandler.post {
try {
progressListener.onProgress(progress, total)
} catch (e: Exception) {
Log.e(LOG_TAG, "prepareKeysBackupVersion: onProgress failure", e)
}
}
}
}
}

val generatePrivateKeyResult = generatePrivateKeyWithPassword(password, backgroundProgressListener)
megolmBackupAuthData.publicKey = olmPkDecryption.setPrivateKey(generatePrivateKeyResult.privateKey)
megolmBackupAuthData.privateKeySalt = generatePrivateKeyResult.salt
megolmBackupAuthData.privateKeyIterations = generatePrivateKeyResult.iterations
Expand Down Expand Up @@ -203,17 +222,21 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {
* @param progress the callback to follow the progress
* @param callback the main callback
*/
fun backupAllGroupSessions(progress: BackupProgressListener?,
fun backupAllGroupSessions(progressListener: ProgressListener?,
callback: ApiCallback<Void?>?) {
// Get a status right now
getBackupProgress(object : BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
getBackupProgress(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
// Reset previous listeners if any
resetBackupAllGroupSessionsListeners()
Log.d(LOG_TAG, "backupAllGroupSessions: backupProgress: $backedUp/$total")
progress?.onProgress(backedUp, total)
Log.d(LOG_TAG, "backupAllGroupSessions: backupProgress: $progress/$total")
try {
progressListener?.onProgress(progress, total)
} catch (e: Exception) {
Log.e(LOG_TAG, "backupAllGroupSessions: onProgress failure", e)
}

if (backedUp == total) {
if (progress == total) {
Log.d(LOG_TAG, "backupAllGroupSessions: complete")
callback?.onSuccess(null)
return
Expand All @@ -224,9 +247,13 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {
// Listen to `state` change to determine when to call onBackupProgress and onComplete
mKeysBackupStateListener = object : KeysBackupStateManager.KeysBackupStateListener {
override fun onStateChange(newState: KeysBackupStateManager.KeysBackupState) {
getBackupProgress(object : BackupProgressListener {
override fun onProgress(backedUp: Int, total: Int) {
progress?.onProgress(backedUp, total)
getBackupProgress(object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
try {
progressListener?.onProgress(progress, total)
} catch (e: Exception) {
Log.e(LOG_TAG, "backupAllGroupSessions: onProgress failure 2", e)
}

// If backup is finished, notify the main listener
if (state === KeysBackupStateManager.KeysBackupState.ReadyToBackUp) {
Expand Down Expand Up @@ -328,16 +355,12 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {
mKeysBackupStateListener = null
}

interface BackupProgressListener {
fun onProgress(backedUp: Int, total: Int)
}

private fun getBackupProgress(listener: BackupProgressListener) {
private fun getBackupProgress(progressListener: ProgressListener) {
mCrypto.decryptingThreadHandler.post {
val backedUpKeys = mCrypto.cryptoStore.inboundGroupSessionsCount(true)
val total = mCrypto.cryptoStore.inboundGroupSessionsCount(false)

mCrypto.uiHandler.post { listener.onProgress(backedUpKeys, total) }
mCrypto.uiHandler.post { progressListener.onProgress(backedUpKeys, total) }
}
}

Expand Down Expand Up @@ -447,7 +470,7 @@ class KeysBackup(private val mCrypto: MXCrypto, session: MXSession) {
return
}

// This is the recovery key
// Compute the recovery key
val privateKey = retrievePrivateKeyWithPassword(password,
megolmBackupAuthData.privateKeySalt!!,
megolmBackupAuthData.privateKeyIterations!!)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package org.matrix.androidsdk.crypto.keysbackup

import org.matrix.androidsdk.listeners.ProgressListener
import org.matrix.androidsdk.util.Log
import java.util.*
import javax.crypto.Mac
Expand All @@ -44,10 +45,10 @@ data class GeneratePrivateKeyResult(
*
* @return a {privateKey, salt, iterations} tuple.
*/
fun generatePrivateKeyWithPassword(password: String): GeneratePrivateKeyResult {
fun generatePrivateKeyWithPassword(password: String, progressListener: ProgressListener?): GeneratePrivateKeyResult {
val salt = generateSalt()
val iterations = DEFAULT_ITERATION
val privateKey = deriveKey(password, salt, iterations)
val privateKey = deriveKey(password, salt, iterations, progressListener)

return GeneratePrivateKeyResult(privateKey, salt, iterations)
}
Expand All @@ -64,20 +65,22 @@ fun generatePrivateKeyWithPassword(password: String): GeneratePrivateKeyResult {
fun retrievePrivateKeyWithPassword(password: String,
salt: String,
iterations: Int): ByteArray {
return deriveKey(password, salt, iterations)
return deriveKey(password, salt, iterations, null)
}

/**
* Compute a private key by deriving a password and a salt strings.
* @param password the password.
* @param salt the salt.
* @param iterations number of derivations.
* @param progressListener a listener to follow progress.
*
* @return a private key.
*/
private fun deriveKey(password: String,
salt: String,
iterations: Int): ByteArray {
iterations: Int,
progressListener: ProgressListener?): ByteArray {
// Note: copied and adapted from MXMegolmExportEncryption
val t0 = System.currentTimeMillis()

Expand All @@ -104,6 +107,8 @@ private fun deriveKey(password: String,
// copy to the key
System.arraycopy(uc, 0, dk, 0, dk.size)

var lastProgress = -1

for (index in 2..iterations) {
// Uc = PRF(Password, Uc-1)
prf.update(uc)
Expand All @@ -113,6 +118,12 @@ private fun deriveKey(password: String,
for (byteIndex in dk.indices) {
dk[byteIndex] = dk[byteIndex] xor uc[byteIndex]
}

val progress = (index + 1) * 100 / iterations
if (progress != lastProgress) {
lastProgress = progress
progressListener?.onProgress(lastProgress, 100)
}
}

Log.d("KeysBackupPassword", "## deriveKeys() : " + iterations + " in " + (System.currentTimeMillis() - t0) + " ms")
Expand Down
Loading

0 comments on commit 7fd0f2e

Please sign in to comment.