Files
ownwire-go-sdk/crypto.go

210 lines
4.6 KiB
Go
Raw Normal View History

package ownwire_sdk
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/hkdf"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
)
// hkdf_info_prefix is the fixed prefix of the HKDF "info" input used by
// DeriveSharedKey; the session ID is appended to it.
const hkdf_info_prefix = "ownwire/v1:"

// nonce_label is the domain-separation label that DeriveNonce writes into the
// HMAC first, ahead of the per-message fields.
const nonce_label = "ownwire/v1:gcm-nonce"
// Keypair holds a client P-256 ECDH private key together with the encoded
// form of its public key.
type Keypair struct {
	ClientPriv *ecdh.PrivateKey
	ClientPub  []byte // 65 bytes, uncompressed
}

// GenClientKey generates a new P-256 key pair using crypto/rand as the entropy
// source.
//
// ClientPub is the public key as returned by ecdh.PublicKey.Bytes, which for
// P-256 is the 65-byte uncompressed point. Any other length is reported as an
// error and a zero Keypair is returned.
func GenClientKey() (Keypair, error) {
	priv, err := ecdh.P256().GenerateKey(rand.Reader)
	if err != nil {
		return Keypair{}, err
	}

	kp := Keypair{
		ClientPriv: priv,
		ClientPub:  priv.PublicKey().Bytes(),
	}
	if n := len(kp.ClientPub); n != 65 {
		return Keypair{}, fmt.Errorf("unexpected P-256 pubkey length: %d", n)
	}
	return kp, nil
}
// DeriveSharedKey does:
// ECDH -> 32 bytes shared secret
// HKDF-SHA256(salt, info="ownwire/v1:<session_id>") -> 32 bytes
func DeriveSharedKey(session_id string, client_priv *ecdh.PrivateKey, server_pub_raw []byte, salt []byte) ([32]byte, error) {
var out [32]byte
if client_priv == nil {
return out, fmt.Errorf("client_priv is nil")
}
curve := ecdh.P256()
server_pub, err := curve.NewPublicKey(server_pub_raw)
if err != nil {
return out, fmt.Errorf("invalid server pubkey: %w", err)
}
shared_secret, err := client_priv.ECDH(server_pub)
if err != nil {
return out, fmt.Errorf("ecdh failed: %w", err)
}
info_str := hkdf_info_prefix + session_id
prk, err := hkdf.Extract(sha256.New, shared_secret, salt)
if err != nil {
return out, fmt.Errorf("hkdf extract failed: %w", err)
}
okm, err := hkdf.Expand(sha256.New, prk, info_str, 32)
if err != nil {
return out, fmt.Errorf("hkdf expand failed: %w", err)
}
copy(out[:], okm)
zero_bytes(shared_secret)
return out, nil
}
// DeriveNonce derives the 12-byte AES-GCM nonce for a single message.
//
// The nonce is the first 12 bytes of
//
//	HMAC-SHA256(key  = shared_key,
//	            data = nonce_label || session_uuid_bytes || salt16 || BE64(seq_num) || flag)
//
// where BE64 is the 8-byte big-endian encoding and flag is a single byte: 0x01
// when is_response is true, 0x00 otherwise. The function is deterministic, so
// sender and receiver compute the same nonce from the same inputs.
func DeriveNonce(shared_key [32]byte, session_uuid_bytes [16]byte, salt16 [16]byte, seq_num uint64, is_response bool) [12]byte {
	var direction byte
	if is_response {
		direction = 1
	}

	// HMAC is a streaming MAC, so feeding one concatenated buffer is the same
	// as writing each field separately.
	msg := make([]byte, 0, len(nonce_label)+len(session_uuid_bytes)+len(salt16)+8+1)
	msg = append(msg, nonce_label...)
	msg = append(msg, session_uuid_bytes[:]...)
	msg = append(msg, salt16[:]...)
	for shift := 56; shift >= 0; shift -= 8 {
		msg = append(msg, byte(seq_num>>uint(shift)))
	}
	msg = append(msg, direction)

	h := hmac.New(sha256.New, shared_key[:])
	h.Write(msg)
	digest := h.Sum(nil)

	var nonce [12]byte
	copy(nonce[:], digest)
	return nonce
}
// EncryptedPayload is the text-encoded result of EncryptAESGCM.
type EncryptedPayload struct {
	// ContentB64 is the AES-GCM Seal output, standard base64.
	ContentB64 string
	// SaltHex is the per-message 16-byte random salt, hex-encoded.
	SaltHex string
}

// EncryptAESGCM seals plain_text under shared_key using AES-GCM with a 32-byte
// key.
//
// Each call draws a fresh random 16-byte salt. The nonce is derived from that
// salt plus session_uuid_bytes, seq_num and is_response via DeriveNonce. The
// salt is returned in SaltHex so the receiver can re-derive the same nonce. No
// additional authenticated data is supplied to Seal.
func EncryptAESGCM(shared_key [32]byte, session_uuid_bytes [16]byte, plain_text []byte, seq_num uint64, is_response bool) (EncryptedPayload, error) {
	var salt16 [16]byte
	if _, err := rand.Read(salt16[:]); err != nil {
		return EncryptedPayload{}, err
	}

	blk, err := aes.NewCipher(shared_key[:])
	if err != nil {
		return EncryptedPayload{}, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return EncryptedPayload{}, err
	}

	nonce := DeriveNonce(shared_key, session_uuid_bytes, salt16, seq_num, is_response)
	sealed := gcm.Seal(nil, nonce[:], plain_text, nil)

	var payload EncryptedPayload
	payload.ContentB64 = base64.StdEncoding.EncodeToString(sealed)
	payload.SaltHex = hex.EncodeToString(salt16[:])
	return payload, nil
}
// DecryptAESGCM opens a message produced by EncryptAESGCM.
//
// content_b64 is the standard-base64 ciphertext and salt_hex the hex-encoded
// 16-byte salt, both as found in an EncryptedPayload. session_uuid_bytes,
// seq_num and is_response feed the nonce derivation (DeriveNonce), so they
// must equal the values used when sealing. A mismatch, or any modification of
// the ciphertext, makes GCM authentication fail and the error from Open is
// returned.
//
// Malformed base64, malformed hex, or a salt that does not decode to exactly
// 16 bytes is reported before any decryption is attempted.
func DecryptAESGCM(shared_key [32]byte, session_uuid_bytes [16]byte, content_b64 string, salt_hex string, seq_num uint64, is_response bool) ([]byte, error) {
	sealed, err := base64.StdEncoding.DecodeString(content_b64)
	if err != nil {
		return nil, fmt.Errorf("invalid content base64: %w", err)
	}

	rawSalt, err := hex.DecodeString(salt_hex)
	if err != nil {
		return nil, fmt.Errorf("invalid salt hex: %w", err)
	}
	var salt16 [16]byte
	if len(rawSalt) != len(salt16) {
		return nil, fmt.Errorf("invalid salt length: %d", len(rawSalt))
	}
	copy(salt16[:], rawSalt)

	blk, err := aes.NewCipher(shared_key[:])
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}

	nonce := DeriveNonce(shared_key, session_uuid_bytes, salt16, seq_num, is_response)
	plain, err := gcm.Open(nil, nonce[:], sealed, nil)
	if err != nil {
		return nil, err
	}
	return plain, nil
}
// ParseUUIDBytes parses a textual UUID into its 16 raw bytes.
//
// Two layouts are accepted:
//   - canonical: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" (36 characters, hyphens
//     exactly at offsets 8, 13, 18 and 23)
//   - compact:   32 hex digits with no hyphens
//
// Hex digits may be upper- or lowercase. Anything else is rejected, including
// hyphens in non-canonical positions and extra or missing hyphens. On error
// the returned array is all zeros.
func ParseUUIDBytes(uuid_str string) ([16]byte, error) {
	var out [16]byte
	var digits [32]byte

	switch len(uuid_str) {
	case 32:
		copy(digits[:], uuid_str)
	case 36:
		n := 0
		for i := 0; i < len(uuid_str); i++ {
			c := uuid_str[i]
			switch i {
			case 8, 13, 18, 23:
				if c != '-' {
					return out, fmt.Errorf("invalid uuid: %q", uuid_str)
				}
			default:
				// 36 characters minus the 4 separators leaves exactly 32 digits.
				digits[n] = c
				n++
			}
		}
	default:
		return out, fmt.Errorf("invalid uuid: %q", uuid_str)
	}

	if _, err := hex.Decode(out[:], digits[:]); err != nil {
		// hex.Decode may have written a prefix before failing; return zeros.
		return [16]byte{}, fmt.Errorf("invalid uuid hex: %w", err)
	}
	return out, nil
}
// zero_bytes overwrites every element of b with 0 in place, and is used to
// scrub secret material such as the raw ECDH output. A nil or empty slice is a
// no-op.
//
// This is best-effort only: it cannot reach copies of the data held elsewhere.
func zero_bytes(b []byte) {
	clear(b)
}