package ownwire_sdk

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/ecdh"
	"crypto/hkdf"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"fmt"
)

const (
	// hkdf_info_prefix is prepended to the session id to form the HKDF info
	// string in DeriveSharedKey.
	hkdf_info_prefix = "ownwire/v1:"
	// nonce_label is the first input to the HMAC in DeriveNonce, separating
	// nonce derivation from any other use of the shared key.
	nonce_label = "ownwire/v1:gcm-nonce"
)

// Keypair holds a client P-256 key pair produced by GenClientKey.
type Keypair struct {
	ClientPriv *ecdh.PrivateKey
	ClientPub  []byte // 65 bytes, uncompressed point encoding (0x04 || X || Y)
}

// GenClientKey generates a fresh P-256 key pair using crypto/rand.
// ClientPub is the 65-byte uncompressed encoding returned by
// ecdh.PublicKey.Bytes.
func GenClientKey() (Keypair, error) {
	client_priv, err := ecdh.P256().GenerateKey(rand.Reader)
	if err != nil {
		return Keypair{}, err
	}

	client_pub := client_priv.PublicKey().Bytes()
	// Defensive invariant check: P-256 public keys from crypto/ecdh are
	// always 65 bytes uncompressed.
	if len(client_pub) != 65 {
		return Keypair{}, fmt.Errorf("unexpected P-256 pubkey length: %d", len(client_pub))
	}

	return Keypair{
		ClientPriv: client_priv,
		ClientPub:  client_pub,
	}, nil
}

// DeriveSharedKey derives the 32-byte session key:
//
//	shared_secret = ECDH-P256(client_priv, server_pub_raw)
//	key           = HKDF-SHA256(ikm=shared_secret, salt=salt,
//	                            info="ownwire/v1:"+session_id, L=32)
//
// server_pub_raw must be an encoding accepted by ecdh.P256().NewPublicKey.
// It returns an error if client_priv is nil, server_pub_raw is not a valid
// P-256 public key, or ECDH/HKDF fails. The ECDH output and the HKDF
// intermediates are wiped before returning, on success and error paths
// alike (best effort: this only clears these slices).
func DeriveSharedKey(session_id string, client_priv *ecdh.PrivateKey, server_pub_raw []byte, salt []byte) ([32]byte, error) {
	var out [32]byte
	if client_priv == nil {
		return out, fmt.Errorf("client_priv is nil")
	}

	server_pub, err := ecdh.P256().NewPublicKey(server_pub_raw)
	if err != nil {
		return out, fmt.Errorf("invalid server pubkey: %w", err)
	}

	shared_secret, err := client_priv.ECDH(server_pub)
	if err != nil {
		return out, fmt.Errorf("ecdh failed: %w", err)
	}
	// Wipe on every return path, not only on success.
	defer zero_bytes(shared_secret)

	prk, err := hkdf.Extract(sha256.New, shared_secret, salt)
	if err != nil {
		return out, fmt.Errorf("hkdf extract failed: %w", err)
	}
	defer zero_bytes(prk)

	okm, err := hkdf.Expand(sha256.New, prk, hkdf_info_prefix+session_id, len(out))
	if err != nil {
		return out, fmt.Errorf("hkdf expand failed: %w", err)
	}
	defer zero_bytes(okm)

	copy(out[:], okm)
	return out, nil
}

// DeriveNonce computes the 12-byte AES-GCM nonce for one message:
//
//	IV = HMAC-SHA256(key=shared_key,
//	                 data=nonce_label || session_uuid_bytes || salt16 ||
//	                      BE64(seq_num) || flag)[:12]
//
// flag is 0x01 when is_response is true and 0x00 otherwise, so a request and
// a response with the same seq_num and salt use different HMAC inputs.
func DeriveNonce(shared_key [32]byte, session_uuid_bytes [16]byte, salt16 [16]byte, seq_num uint64, is_response bool) [12]byte {
	var seq_be8 [8]byte
	binary.BigEndian.PutUint64(seq_be8[:], seq_num)

	flag := byte(0)
	if is_response {
		flag = 1
	}

	mac := hmac.New(sha256.New, shared_key[:])
	// hash.Hash.Write never returns an error, so results are not checked.
	mac.Write([]byte(nonce_label))
	mac.Write(session_uuid_bytes[:])
	mac.Write(salt16[:])
	mac.Write(seq_be8[:])
	mac.Write([]byte{flag})

	var iv [12]byte
	copy(iv[:], mac.Sum(nil)) // keep the first 12 of 32 bytes
	return iv
}

// EncryptedPayload is the text-encoded result of EncryptAESGCM.
type EncryptedPayload struct {
	ContentB64 string // standard base64 of the GCM output (ciphertext || tag)
	SaltHex    string // lowercase hex of the 16-byte per-message salt
}

// new_gcm_aead returns an AES-GCM AEAD keyed with the 32-byte shared_key
// (AES-256, standard 12-byte nonce).
func new_gcm_aead(shared_key [32]byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(shared_key[:])
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// EncryptAESGCM seals plain_text under shared_key with AES-256-GCM and no
// additional authenticated data.
//
// A fresh random 16-byte salt is drawn per call and fed, together with
// session_uuid_bytes, seq_num and is_response, into DeriveNonce. Only the
// salt is carried in the result; the decrypting side must supply the same
// session_uuid_bytes, seq_num and is_response.
func EncryptAESGCM(shared_key [32]byte, session_uuid_bytes [16]byte, plain_text []byte, seq_num uint64, is_response bool) (EncryptedPayload, error) {
	var salt16 [16]byte
	if _, err := rand.Read(salt16[:]); err != nil {
		return EncryptedPayload{}, err
	}

	aead, err := new_gcm_aead(shared_key)
	if err != nil {
		return EncryptedPayload{}, err
	}

	iv := DeriveNonce(shared_key, session_uuid_bytes, salt16, seq_num, is_response)
	ct := aead.Seal(nil, iv[:], plain_text, nil)

	return EncryptedPayload{
		ContentB64: base64.StdEncoding.EncodeToString(ct),
		SaltHex:    hex.EncodeToString(salt16[:]),
	}, nil
}

// DecryptAESGCM reverses EncryptAESGCM.
//
// It returns an error if content_b64 is not valid standard base64, if
// salt_hex does not decode to exactly 16 bytes, or if GCM authentication
// fails (wrong key, session, seq_num, is_response or salt, or tampered
// ciphertext).
func DecryptAESGCM(shared_key [32]byte, session_uuid_bytes [16]byte, content_b64 string, salt_hex string, seq_num uint64, is_response bool) ([]byte, error) {
	ct, err := base64.StdEncoding.DecodeString(content_b64)
	if err != nil {
		return nil, fmt.Errorf("invalid content base64: %w", err)
	}

	salt_raw, err := hex.DecodeString(salt_hex)
	if err != nil {
		return nil, fmt.Errorf("invalid salt hex: %w", err)
	}
	if len(salt_raw) != 16 {
		return nil, fmt.Errorf("invalid salt length: %d", len(salt_raw))
	}
	var salt16 [16]byte
	copy(salt16[:], salt_raw)

	aead, err := new_gcm_aead(shared_key)
	if err != nil {
		return nil, err
	}

	iv := DeriveNonce(shared_key, session_uuid_bytes, salt16, seq_num, is_response)
	pt, err := aead.Open(nil, iv[:], ct, nil)
	if err != nil {
		return nil, err
	}
	return pt, nil
}

// ParseUUIDBytes parses "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" into 16 bytes.
//
// Every '-' is ignored regardless of position, so the undashed 32-digit form
// is accepted as well. The remaining characters must be exactly 32 hex
// digits (either case). On error the returned array is all zeros.
func ParseUUIDBytes(uuid_str string) ([16]byte, error) {
	var out [16]byte

	clean := make([]byte, 0, 32)
	for i := 0; i < len(uuid_str); i++ {
		if c := uuid_str[i]; c != '-' {
			clean = append(clean, c)
		}
	}
	if len(clean) != 32 {
		return out, fmt.Errorf("invalid uuid: %q", uuid_str)
	}

	// Decode straight into the result; reset it if a bad digit is found
	// part-way through.
	if _, err := hex.Decode(out[:], clean); err != nil {
		return [16]byte{}, fmt.Errorf("invalid uuid hex: %w", err)
	}
	return out, nil
}

// zero_bytes overwrites every byte of b with zero. It is a best-effort wipe
// of key material held in b.
func zero_bytes(b []byte) {
	for i := range b {
		b[i] = 0
	}
}