Mirror of https://github.com/rcourtman/Pulse.git
Synced 2026-02-19 07:50:43 +01:00
256 lines, 7.5 KiB, Go
package relay
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/ecdh"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"crypto/sha256"
|
|
"golang.org/x/crypto/hkdf"
|
|
"io"
|
|
)
|
|
|
|
const (
	// aesKeySize is the length, in bytes, of each derived AES-256 key.
	aesKeySize = 32

	// nonceSize is the length, in bytes, of the AES-GCM nonce that is
	// prepended to every ciphertext on the wire.
	nonceSize = 12

	// HKDF "info" labels, one per traffic direction. Using distinct labels
	// with the same shared secret yields an independent key for each
	// direction.
	hkdfInfoAppToInstance = "relay-e2e-app-to-instance" // app → instance traffic
	hkdfInfoInstanceToApp = "relay-e2e-instance-to-app" // instance → app traffic
)
|
|
|
|
var (
	// ErrNonceOverflow is returned when the outbound nonce counter is
	// exhausted and cannot be advanced without wrapping.
	ErrNonceOverflow = errors.New("nonce counter overflow")

	// ErrCiphertextTooShort is returned when an inbound message is too small
	// to contain a nonce and a GCM authentication tag.
	ErrCiphertextTooShort = errors.New("ciphertext too short: need at least nonce + tag")

	// ErrKeyExchangeTooShort is returned when a KEY_EXCHANGE payload is
	// truncated.
	ErrKeyExchangeTooShort = errors.New("key exchange payload too short")

	// ErrNonceReplay is returned when an inbound nonce counter is lower than
	// the next expected value (a replayed or reordered message).
	ErrNonceReplay = errors.New("nonce replay or out-of-order: expected higher nonce")

	// ErrIdentityKeyRequired reports that an identity private key is needed
	// to sign a key exchange.
	ErrIdentityKeyRequired = errors.New("identity private key required for key exchange signing")
)
|
|
|
|
// channelCipher is the AES-GCM state for a single direction of a channel,
// together with the nonce counters used on that direction.
type channelCipher struct {
	mu sync.Mutex // guards nonce and recvNonce

	aead cipher.AEAD // GCM instance keyed for this direction

	// nonce is the next counter value to use when sealing (send side).
	nonce uint64
	// recvNonce is the lowest counter value still accepted when opening
	// (receive side); anything below it is treated as a replay.
	recvNonce uint64
}
|
|
|
|
// ChannelEncryption holds the full encryption state for a channel.
|
|
type ChannelEncryption struct {
|
|
sendCipher *channelCipher // outbound direction
|
|
recvCipher *channelCipher // inbound direction
|
|
}
|
|
|
|
// GenerateEphemeralKeyPair returns a fresh X25519 private key, generated from
// crypto/rand, for use in a single key exchange. The matching public key is
// available through the key's PublicKey method.
func GenerateEphemeralKeyPair() (*ecdh.PrivateKey, error) {
	curve := ecdh.X25519()
	priv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	return priv, nil
}
|
|
|
|
// DeriveChannelKeys performs ECDH + HKDF to produce directional AES-256-GCM ciphers.
|
|
// iAmInstance determines which direction maps to send vs recv:
|
|
// - instance: send = instance→app, recv = app→instance
|
|
// - app: send = app→instance, recv = instance→app
|
|
func DeriveChannelKeys(myPrivate *ecdh.PrivateKey, theirPublic *ecdh.PublicKey, iAmInstance bool) (*ChannelEncryption, error) {
|
|
sharedSecret, err := myPrivate.ECDH(theirPublic)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ecdh: %w", err)
|
|
}
|
|
|
|
// Derive app→instance key
|
|
a2iKey, err := deriveKey(sharedSecret, hkdfInfoAppToInstance)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("derive app→instance key: %w", err)
|
|
}
|
|
|
|
// Derive instance→app key
|
|
i2aKey, err := deriveKey(sharedSecret, hkdfInfoInstanceToApp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("derive instance→app key: %w", err)
|
|
}
|
|
|
|
a2iCipher, err := newChannelCipher(a2iKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create app→instance cipher: %w", err)
|
|
}
|
|
|
|
i2aCipher, err := newChannelCipher(i2aKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create instance→app cipher: %w", err)
|
|
}
|
|
|
|
if iAmInstance {
|
|
return &ChannelEncryption{
|
|
sendCipher: i2aCipher, // instance sends on instance→app
|
|
recvCipher: a2iCipher, // instance receives on app→instance
|
|
}, nil
|
|
}
|
|
return &ChannelEncryption{
|
|
sendCipher: a2iCipher, // app sends on app→instance
|
|
recvCipher: i2aCipher, // app receives on instance→app
|
|
}, nil
|
|
}
|
|
|
|
func deriveKey(secret []byte, info string) ([]byte, error) {
|
|
hkdfReader := hkdf.New(sha256.New, secret, nil, []byte(info))
|
|
key := make([]byte, aesKeySize)
|
|
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
|
return nil, fmt.Errorf("hkdf read: %w", err)
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func newChannelCipher(key []byte) (*channelCipher, error) {
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
aead, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &channelCipher{aead: aead}, nil
|
|
}
|
|
|
|
// Encrypt seals plaintext with an incrementing nonce.
|
|
// Output format: [12-byte nonce][ciphertext + 16-byte GCM tag]
|
|
func (ce *ChannelEncryption) Encrypt(plaintext []byte) ([]byte, error) {
|
|
c := ce.sendCipher
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
nonce, err := c.nextNonce()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
|
|
|
// Prepend nonce
|
|
out := make([]byte, nonceSize+len(ciphertext))
|
|
copy(out[:nonceSize], nonce)
|
|
copy(out[nonceSize:], ciphertext)
|
|
return out, nil
|
|
}
|
|
|
|
// Decrypt opens ciphertext in the format [12-byte nonce][ciphertext + tag].
|
|
// It enforces strict nonce monotonicity to prevent replay attacks: each
|
|
// received nonce must be strictly greater than or equal to the expected
|
|
// next nonce. The expected nonce advances after each successful decryption.
|
|
func (ce *ChannelEncryption) Decrypt(data []byte) ([]byte, error) {
|
|
c := ce.recvCipher
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
minSize := nonceSize + c.aead.Overhead()
|
|
if len(data) < minSize {
|
|
return nil, ErrCiphertextTooShort
|
|
}
|
|
|
|
nonce := data[:nonceSize]
|
|
ciphertext := data[nonceSize:]
|
|
|
|
// Extract nonce counter (little-endian uint64 in first 8 bytes)
|
|
receivedNonce := binary.LittleEndian.Uint64(nonce[:8])
|
|
if receivedNonce < c.recvNonce {
|
|
return nil, ErrNonceReplay
|
|
}
|
|
|
|
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gcm open: %w", err)
|
|
}
|
|
|
|
// Only advance after successful decryption
|
|
c.recvNonce = receivedNonce + 1
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
// nextNonce returns the next incrementing nonce and advances the counter.
|
|
func (c *channelCipher) nextNonce() ([]byte, error) {
|
|
n := c.nonce
|
|
if n == ^uint64(0) {
|
|
return nil, ErrNonceOverflow
|
|
}
|
|
c.nonce++
|
|
|
|
nonce := make([]byte, nonceSize)
|
|
binary.LittleEndian.PutUint64(nonce[:8], n)
|
|
// bytes 8-11 remain zero (upper 32 bits of uint96)
|
|
return nonce, nil
|
|
}
|
|
|
|
// SignKeyExchange produces an Ed25519 signature over ephemeralPub using the
// identity private key given in standard base64. The decoded key must be
// exactly ed25519.PrivateKeySize bytes; otherwise an error is returned.
func SignKeyExchange(ephemeralPub []byte, identityPrivateKeyB64 string) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(identityPrivateKeyB64)
	switch {
	case err != nil:
		return nil, fmt.Errorf("decode private key: %w", err)
	case len(raw) != ed25519.PrivateKeySize:
		return nil, fmt.Errorf("invalid private key length: got %d, want %d", len(raw), ed25519.PrivateKeySize)
	}
	return ed25519.Sign(ed25519.PrivateKey(raw), ephemeralPub), nil
}
|
|
|
|
// VerifyKeyExchangeSignature checks that signature is a valid Ed25519
// signature over ephemeralPub by the identity key given in standard base64.
// It returns nil on success. It returns a non-nil error if the key cannot be
// decoded, has the wrong length, or the signature does not verify.
func VerifyKeyExchangeSignature(ephemeralPub, signature []byte, identityPublicKeyB64 string) error {
	raw, err := base64.StdEncoding.DecodeString(identityPublicKeyB64)
	if err != nil {
		return fmt.Errorf("decode public key: %w", err)
	}
	if got, want := len(raw), ed25519.PublicKeySize; got != want {
		return fmt.Errorf("invalid public key length: got %d, want %d", got, want)
	}
	if ok := ed25519.Verify(ed25519.PublicKey(raw), ephemeralPub, signature); !ok {
		return errors.New("key exchange signature verification failed")
	}
	return nil
}
|
|
|
|
// MarshalKeyExchangePayload encodes a KEY_EXCHANGE payload as binary.
//
// Wire format: [1 byte pubkey len][pubkey][signature or nothing]
//
// The length prefix is a single byte, so pub must be at most 255 bytes.
// Longer keys are not detected here and would be written with a truncated
// length. sig is appended verbatim and may be empty.
func MarshalKeyExchangePayload(pub []byte, sig []byte) []byte {
	buf := make([]byte, 0, 1+len(pub)+len(sig))
	buf = append(buf, byte(len(pub)))
	buf = append(buf, pub...)
	buf = append(buf, sig...)
	return buf
}
|
|
|
|
// UnmarshalKeyExchangePayload decodes a KEY_EXCHANGE payload.
|
|
func UnmarshalKeyExchangePayload(data []byte) (pub []byte, sig []byte, err error) {
|
|
if len(data) < 2 {
|
|
return nil, nil, ErrKeyExchangeTooShort
|
|
}
|
|
|
|
pubLen := int(data[0])
|
|
if len(data) < 1+pubLen {
|
|
return nil, nil, ErrKeyExchangeTooShort
|
|
}
|
|
|
|
pub = data[1 : 1+pubLen]
|
|
if len(data) > 1+pubLen {
|
|
sig = data[1+pubLen:]
|
|
}
|
|
return pub, sig, nil
|
|
}
|