diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index 7fe8bd05ab..191f062204 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging -import scodec.Attempt +import scodec.{Attempt, DecodeResult} import scodec.bits.ByteVector import scala.annotation.tailrec @@ -347,10 +347,10 @@ object Sphinx extends Logging { def create(sharedSecret: ByteVector32, failure: FailureMessage, holdTime: FiniteDuration): ByteVector = { val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector - val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0)) - val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector.low(4))) + val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength)) + val zeroHmacs = maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4))) val plainError = attributableErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(AttributableError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes - wrap(plainError, sharedSecret, holdTime, isSource = true).get + wrap(plainError, sharedSecret, holdTime, isSource = true) } private def computeHmacs(mac: Mac32, failurePayload: ByteVector, hopPayloads: Seq[ByteVector], hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = { @@ -363,9 +363,12 @@ object Sphinx extends Logging { newHmacs } - def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): Try[ByteVector] = Try { + def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): ByteVector = { val um = generateKey("um", sharedSecret) - val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value + val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits) match { + case Attempt.Successful(DecodeResult(value, _)) => value + case Attempt.Failure(_) => AttributableError.zero(payloadAndPadLength, hopPayloadLength, maxNumHop) + } val hopPayloads = hopPayloadCodec.encode(HopPayload(isSource, holdTime)).require.bytes +: error.hopPayloads.dropRight(1) val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1)), 0) +: error.hmacs.dropRight(1).map(_.drop(1)) val newError = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(AttributableError(error.failurePayload, hopPayloads, hmacs)).require.bytes @@ -374,14 +377,6 @@ object Sphinx extends Logging { newError xor stream } - def wrapOrCreate(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration): ByteVector = - wrap(errorPacket, sharedSecret, holdTime, isSource = false) match { - case Failure(_) => - // There is no failure message for this use-case, using TemporaryNodeFailure instead. - create(sharedSecret, TemporaryNodeFailure(), holdTime) - case Success(value) => value - } - private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try { val key = generateKey("ammag", sharedSecret) val stream = generateStream(key, errorPacket.length.toInt) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index da0395a3e2..28cb41364b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -314,7 +314,7 @@ object OutgoingPaymentPacket { Sphinx.peel(nodeSecret, Some(add.paymentHash), add.onionRoutingPacket) match { case Right(Sphinx.DecryptedPacket(_, _, sharedSecret)) => val encryptedReason = reason match { - case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrapOrCreate(forwarded, sharedSecret, holdTime) + case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrap(forwarded, sharedSecret, holdTime, isSource = false) case Right(failure) if useAttributableErrors => Sphinx.AttributableErrorPacket.create(sharedSecret, failure, holdTime) case Left(forwarded) => Sphinx.FailurePacket.wrap(forwarded, sharedSecret) case Right(failure) => Sphinx.FailurePacket.create(sharedSecret, failure) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/AttributableError.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/AttributableError.scala index 785d14d094..79c8001336 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/AttributableError.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/AttributableError.scala @@ -48,4 +48,10 @@ object AttributableError { (("failure_payload" | bytes(totalLength - metadataLength)) :: ("hop_payloads" | listOfN(provide(maxNumHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) :: ("hmacs" | hmacsCodec(maxNumHop))).as[AttributableError].complete} + + def zero(payloadAndPadLength: Int, hopPayloadLength: Int, maxNumHop: Int): AttributableError = + AttributableError( + ByteVector.low(payloadAndPadLength), + Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength)), + maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4)))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 2560821fdf..03d96ac657 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -414,13 +414,13 @@ class SphinxSpec extends AnyFunSuite { val Right(decrypted1) = AttributableErrorPacket.decrypt(packet1, (2 to 4).map(i => (sharedSecrets(i), publicKeys(i)))) assert(decrypted1 == expected) - val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false) + val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false) assert(packet2.length == 1200) val Right(decrypted2) = AttributableErrorPacket.decrypt(packet2, (1 to 4).map(i => (sharedSecrets(i), publicKeys(i)))) assert(decrypted2 == expected) - val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false) + val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false) assert(packet3.length == 1200) val Right(decrypted3) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i)))) @@ -440,11 +440,11 @@ class SphinxSpec extends AnyFunSuite { val packet1 = randomBytes(1200) val hopPayload2 = AttributableError.HopPayload(isPayloadSource = false, 50 millis) - val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false) + val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false) assert(packet2.length == 1200) val hopPayload3 = AttributableError.HopPayload(isPayloadSource = false, 100 millis) - val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false) + val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false) assert(packet3.length == 1200) val Left(decryptionError) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))