mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-02-19 07:50:43 +01:00
339 lines
9.3 KiB
Go
339 lines
9.3 KiB
Go
package relay
|
|
|
|
import (
	"bytes"
	"crypto/ecdh"
	"errors"
	"testing"
)
|
|
|
|
func TestGenerateEphemeralKeyPair(t *testing.T) {
|
|
key1, err := GenerateEphemeralKeyPair()
|
|
if err != nil {
|
|
t.Fatalf("GenerateEphemeralKeyPair: %v", err)
|
|
}
|
|
if len(key1.PublicKey().Bytes()) != 32 {
|
|
t.Errorf("public key length: got %d, want 32", len(key1.PublicKey().Bytes()))
|
|
}
|
|
|
|
key2, err := GenerateEphemeralKeyPair()
|
|
if err != nil {
|
|
t.Fatalf("GenerateEphemeralKeyPair (second): %v", err)
|
|
}
|
|
if bytes.Equal(key1.PublicKey().Bytes(), key2.PublicKey().Bytes()) {
|
|
t.Error("two generated keys should be different")
|
|
}
|
|
}
|
|
|
|
func TestDeriveChannelKeys_Roundtrip(t *testing.T) {
|
|
instancePriv, err := GenerateEphemeralKeyPair()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
appPriv, err := GenerateEphemeralKeyPair()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
instanceEnc, err := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
if err != nil {
|
|
t.Fatalf("DeriveChannelKeys (instance): %v", err)
|
|
}
|
|
appEnc, err := DeriveChannelKeys(appPriv, instancePriv.PublicKey(), false)
|
|
if err != nil {
|
|
t.Fatalf("DeriveChannelKeys (app): %v", err)
|
|
}
|
|
|
|
// App encrypts, instance decrypts
|
|
plaintext := []byte("hello from the app")
|
|
ciphertext, err := appEnc.Encrypt(plaintext)
|
|
if err != nil {
|
|
t.Fatalf("app Encrypt: %v", err)
|
|
}
|
|
decrypted, err := instanceEnc.Decrypt(ciphertext)
|
|
if err != nil {
|
|
t.Fatalf("instance Decrypt: %v", err)
|
|
}
|
|
if !bytes.Equal(plaintext, decrypted) {
|
|
t.Errorf("roundtrip app→instance: got %q, want %q", decrypted, plaintext)
|
|
}
|
|
|
|
// Instance encrypts, app decrypts
|
|
plaintext2 := []byte("hello from the instance")
|
|
ciphertext2, err := instanceEnc.Encrypt(plaintext2)
|
|
if err != nil {
|
|
t.Fatalf("instance Encrypt: %v", err)
|
|
}
|
|
decrypted2, err := appEnc.Decrypt(ciphertext2)
|
|
if err != nil {
|
|
t.Fatalf("app Decrypt: %v", err)
|
|
}
|
|
if !bytes.Equal(plaintext2, decrypted2) {
|
|
t.Errorf("roundtrip instance→app: got %q, want %q", decrypted2, plaintext2)
|
|
}
|
|
}
|
|
|
|
func TestChannelEncryption_IncrementingNonce(t *testing.T) {
|
|
instancePriv, _ := GenerateEphemeralKeyPair()
|
|
appPriv, _ := GenerateEphemeralKeyPair()
|
|
|
|
enc, err := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
plaintext := []byte("same plaintext")
|
|
ct1, err := enc.Encrypt(plaintext)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ct2, err := enc.Encrypt(plaintext)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if bytes.Equal(ct1, ct2) {
|
|
t.Error("same plaintext should produce different ciphertext with incrementing nonces")
|
|
}
|
|
|
|
// Verify nonces are different (first 12 bytes)
|
|
if bytes.Equal(ct1[:nonceSize], ct2[:nonceSize]) {
|
|
t.Error("nonces should differ between calls")
|
|
}
|
|
}
|
|
|
|
func TestChannelEncryption_ReplayRejected(t *testing.T) {
|
|
instancePriv, _ := GenerateEphemeralKeyPair()
|
|
appPriv, _ := GenerateEphemeralKeyPair()
|
|
|
|
instanceEnc, _ := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
appEnc, _ := DeriveChannelKeys(appPriv, instancePriv.PublicKey(), false)
|
|
|
|
// App encrypts a message
|
|
ciphertext, err := appEnc.Encrypt([]byte("pay $100"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Instance decrypts it successfully the first time
|
|
plaintext, err := instanceEnc.Decrypt(ciphertext)
|
|
if err != nil {
|
|
t.Fatalf("first decrypt: %v", err)
|
|
}
|
|
if string(plaintext) != "pay $100" {
|
|
t.Fatalf("first decrypt: got %q", plaintext)
|
|
}
|
|
|
|
// Replaying the same frame must fail
|
|
_, err = instanceEnc.Decrypt(ciphertext)
|
|
if err == nil {
|
|
t.Fatal("expected error on replay, got nil")
|
|
}
|
|
if err != ErrNonceReplay {
|
|
t.Errorf("expected ErrNonceReplay, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestChannelEncryption_OutOfOrderNonceRejected(t *testing.T) {
|
|
instancePriv, _ := GenerateEphemeralKeyPair()
|
|
appPriv, _ := GenerateEphemeralKeyPair()
|
|
|
|
instanceEnc, _ := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
appEnc, _ := DeriveChannelKeys(appPriv, instancePriv.PublicKey(), false)
|
|
|
|
// App sends three messages
|
|
ct0, _ := appEnc.Encrypt([]byte("msg 0"))
|
|
ct1, _ := appEnc.Encrypt([]byte("msg 1"))
|
|
ct2, _ := appEnc.Encrypt([]byte("msg 2"))
|
|
|
|
// Instance receives them in order: 0, then 2 (skipping 1)
|
|
_, err := instanceEnc.Decrypt(ct0)
|
|
if err != nil {
|
|
t.Fatalf("decrypt ct0: %v", err)
|
|
}
|
|
|
|
// Skip ct1, decrypt ct2 — nonce 2 > expected 1, should succeed
|
|
_, err = instanceEnc.Decrypt(ct2)
|
|
if err != nil {
|
|
t.Fatalf("decrypt ct2 (skip): %v", err)
|
|
}
|
|
|
|
// Now try ct1 — nonce 1 < expected 3, must fail
|
|
_, err = instanceEnc.Decrypt(ct1)
|
|
if err == nil {
|
|
t.Fatal("expected error on out-of-order nonce, got nil")
|
|
}
|
|
if err != ErrNonceReplay {
|
|
t.Errorf("expected ErrNonceReplay, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestChannelEncryption_TamperedCiphertext(t *testing.T) {
|
|
instancePriv, _ := GenerateEphemeralKeyPair()
|
|
appPriv, _ := GenerateEphemeralKeyPair()
|
|
|
|
instanceEnc, _ := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
appEnc, _ := DeriveChannelKeys(appPriv, instancePriv.PublicKey(), false)
|
|
|
|
ciphertext, err := instanceEnc.Encrypt([]byte("secret data"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Tamper with the ciphertext body (after nonce)
|
|
tampered := make([]byte, len(ciphertext))
|
|
copy(tampered, ciphertext)
|
|
tampered[nonceSize+5] ^= 0xFF
|
|
|
|
_, err = appEnc.Decrypt(tampered)
|
|
if err == nil {
|
|
t.Error("expected error decrypting tampered ciphertext")
|
|
}
|
|
}
|
|
|
|
func TestSignAndVerifyKeyExchange(t *testing.T) {
|
|
// Generate identity keypair
|
|
privB64, pubB64, _, err := GenerateIdentityKeyPair()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ephemeralPub := []byte("32-byte-ephemeral-public-key!!!!") // 32 bytes
|
|
|
|
sig, err := SignKeyExchange(ephemeralPub, privB64)
|
|
if err != nil {
|
|
t.Fatalf("SignKeyExchange: %v", err)
|
|
}
|
|
if len(sig) != 64 {
|
|
t.Errorf("signature length: got %d, want 64", len(sig))
|
|
}
|
|
|
|
// Verify succeeds
|
|
if err := VerifyKeyExchangeSignature(ephemeralPub, sig, pubB64); err != nil {
|
|
t.Errorf("VerifyKeyExchangeSignature: %v", err)
|
|
}
|
|
|
|
// Tampered data fails
|
|
tampered := make([]byte, len(ephemeralPub))
|
|
copy(tampered, ephemeralPub)
|
|
tampered[0] ^= 0xFF
|
|
if err := VerifyKeyExchangeSignature(tampered, sig, pubB64); err == nil {
|
|
t.Error("expected verification to fail on tampered data")
|
|
}
|
|
|
|
// Tampered signature fails
|
|
badSig := make([]byte, len(sig))
|
|
copy(badSig, sig)
|
|
badSig[0] ^= 0xFF
|
|
if err := VerifyKeyExchangeSignature(ephemeralPub, badSig, pubB64); err == nil {
|
|
t.Error("expected verification to fail on tampered signature")
|
|
}
|
|
}
|
|
|
|
func TestMarshalUnmarshalKeyExchangePayload(t *testing.T) {
|
|
// With signature
|
|
pub := bytes.Repeat([]byte{0xAA}, 32)
|
|
sig := bytes.Repeat([]byte{0xBB}, 64)
|
|
|
|
data := MarshalKeyExchangePayload(pub, sig)
|
|
gotPub, gotSig, err := UnmarshalKeyExchangePayload(data)
|
|
if err != nil {
|
|
t.Fatalf("UnmarshalKeyExchangePayload (with sig): %v", err)
|
|
}
|
|
if !bytes.Equal(pub, gotPub) {
|
|
t.Error("public key mismatch")
|
|
}
|
|
if !bytes.Equal(sig, gotSig) {
|
|
t.Error("signature mismatch")
|
|
}
|
|
|
|
// Without signature
|
|
data2 := MarshalKeyExchangePayload(pub, nil)
|
|
gotPub2, gotSig2, err := UnmarshalKeyExchangePayload(data2)
|
|
if err != nil {
|
|
t.Fatalf("UnmarshalKeyExchangePayload (no sig): %v", err)
|
|
}
|
|
if !bytes.Equal(pub, gotPub2) {
|
|
t.Error("public key mismatch (no sig)")
|
|
}
|
|
if gotSig2 != nil {
|
|
t.Error("expected nil signature")
|
|
}
|
|
|
|
// Too short
|
|
_, _, err = UnmarshalKeyExchangePayload([]byte{0x01})
|
|
if err != ErrKeyExchangeTooShort {
|
|
t.Errorf("expected ErrKeyExchangeTooShort, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDeriveChannelKeys_DirectionalKeys(t *testing.T) {
|
|
instancePriv, _ := GenerateEphemeralKeyPair()
|
|
appPriv, _ := GenerateEphemeralKeyPair()
|
|
|
|
instanceEnc, err := DeriveChannelKeys(instancePriv, appPriv.PublicKey(), true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Encrypt in both directions with the same plaintext
|
|
plaintext := []byte("test message")
|
|
|
|
sendCt, err := instanceEnc.Encrypt(plaintext)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// The send cipher (instance→app) and recv cipher (app→instance) use different keys.
|
|
// Verify by checking that the recv cipher cannot decrypt what the send cipher produced.
|
|
_, err = instanceEnc.Decrypt(sendCt)
|
|
if err == nil {
|
|
t.Error("expected error: recv cipher should not decrypt send cipher output (different keys)")
|
|
}
|
|
}
|
|
|
|
// TestDeriveChannelKeys_CrossDirection verifies that two sides with swapped roles
|
|
// cannot cross-decrypt (i.e., the app's send direction matches the instance's recv).
|
|
func TestDeriveChannelKeys_CrossDirection(t *testing.T) {
|
|
priv1, _ := ecdh.X25519().GenerateKey(fakeRandReader(1))
|
|
priv2, _ := ecdh.X25519().GenerateKey(fakeRandReader(2))
|
|
|
|
enc1, err := DeriveChannelKeys(priv1, priv2.PublicKey(), true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
enc2, err := DeriveChannelKeys(priv2, priv1.PublicKey(), false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// enc1.send (instance→app) should be decryptable by enc2.recv (instance→app)
|
|
ct, err := enc1.Encrypt([]byte("from instance"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pt, err := enc2.Decrypt(ct)
|
|
if err != nil {
|
|
t.Fatalf("cross-direction decrypt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(pt, []byte("from instance")) {
|
|
t.Errorf("cross-direction: got %q", pt)
|
|
}
|
|
}
|
|
|
|
// fakeRandReader returns a deterministic reader for testing. NOT cryptographically secure.
|
|
func fakeRandReader(seed byte) *deterministicReader {
|
|
return &deterministicReader{val: seed}
|
|
}
|
|
|
|
// deterministicReader is a test-only byte source that produces a fixed,
// reproducible sequence derived from its current val.
type deterministicReader struct {
	val byte // next byte to emit
}

// Read fills p completely with the reader's sequence and never fails: it
// always returns len(p) and a nil error. Each emitted byte is the current
// val, after which val advances to val*7 + 13 (byte arithmetic, so it wraps
// modulo 256).
func (r *deterministicReader) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = r.val
		r.val = r.val*7 + 13 // simple LCG-style mutation
	}
	return len(p), nil
}
|