diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..d2ff01a --- /dev/null +++ b/conn.go @@ -0,0 +1,10 @@ +package ownwire_sdk + +import ( + "context" +) + +type Conn interface { + WriteText(ctx context.Context, s string) error + ReadText(ctx context.Context) (string, error) +} diff --git a/conn_inmem_test.go b/conn_inmem_test.go new file mode 100644 index 0000000..e763270 --- /dev/null +++ b/conn_inmem_test.go @@ -0,0 +1,44 @@ +// This file defines an in-memory implementation of the Conn interface that is +// used ONLY by unit tests. +// +// The inmem_conn replaces a real WebSocket or network connection with simple +// Go channels. This allows handshake, sequencing, and message logic to be +// tested deterministically without spinning up a server, opening sockets, +// or relying on timing. + +package ownwire_sdk + +import ( + "context" +) + +type inmem_conn struct { + write_ch chan string + read_ch chan string +} + +func new_inmem_conn() *inmem_conn { + return &inmem_conn{ + write_ch: make(chan string, 16), + read_ch: make(chan string, 16), + } +} + +func (c *inmem_conn) WriteText(ctx context.Context, s string) error { + select { + case c.write_ch <- s: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *inmem_conn) ReadText(ctx context.Context) (string, error) { + select { + case s := <-c.read_ch: + return s, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + diff --git a/handshake.go b/handshake.go new file mode 100644 index 0000000..94fb5e5 --- /dev/null +++ b/handshake.go @@ -0,0 +1,118 @@ +package ownwire_sdk + +import ( + "context" + "encoding/base64" + "fmt" + "time" +) + +type SessionState struct { + SessionId string + ClientPubKeyB64 string + + ServerPubKeyB64 string + SaltB64 string + + SharedKey [32]byte + SessionIdBytes [16]byte + + SeqOut uint64 + SeqInMax uint64 +} + +type Handshaker struct { + Timeout time.Duration + GenClientKeyF func() (Keypair, error) +} + +func (h Handshaker) EnsureDefaults() Handshaker { + if h.Timeout == 0 { + h.Timeout = 10 * time.Second + } + if h.GenClientKeyF == nil { + h.GenClientKeyF = GenClientKey + } + return h +} + +func (h Handshaker) Run(ctx context.Context, conn Conn, resume_session_id string) (SessionState, []string, error) { + h = h.EnsureDefaults() + + timeout_ctx, cancel := context.WithTimeout(ctx, h.Timeout) + defer cancel() + + keypair, err := h.GenClientKeyF() + if err != nil { + return SessionState{}, nil, err + } + + client_pub_b64 := base64.StdEncoding.EncodeToString(keypair.ClientPub) + + if resume_session_id != "" { + open_cmd := fmt.Sprintf("/open:%s:%s", resume_session_id, client_pub_b64) + if err := conn.WriteText(timeout_ctx, open_cmd); err != nil { + return SessionState{}, nil, err + } + } else { + create_cmd := fmt.Sprintf("/create:%s", client_pub_b64) + if err := conn.WriteText(timeout_ctx, create_cmd); err != nil { + return SessionState{}, nil, err + } + } + + pending := make([]string, 0, 8) + + for { + line, err := conn.ReadText(timeout_ctx) + if err != nil { + return SessionState{}, nil, err + } + + if len(line) >= 9 && line[:9] == "/session:" { + parsed, err := ParseSessionInit(line) + if err != nil { + return SessionState{}, nil, err + } + + server_pub_raw, err := base64.StdEncoding.DecodeString(parsed.ServerPubKeyB64) + if err != nil { + return SessionState{}, nil, fmt.Errorf("invalid server pubkey base64: %w", err) + } + + salt_raw, err := base64.StdEncoding.DecodeString(parsed.SaltB64) + if err != nil { + return SessionState{}, nil, fmt.Errorf("invalid salt base64: %w", err) + } + + shared_key, err := DeriveSharedKey(parsed.SessionId, keypair.ClientPriv, server_pub_raw, salt_raw) + if err != nil { + return SessionState{}, nil, err + } + + session_id_bytes, err := ParseUUIDBytes(parsed.SessionId) + if err != nil { + return SessionState{}, nil, err + } + + out := SessionState{ + SessionId: parsed.SessionId, + ClientPubKeyB64: client_pub_b64, + + ServerPubKeyB64: parsed.ServerPubKeyB64, + SaltB64: parsed.SaltB64, + + SharedKey: shared_key, + SessionIdBytes: session_id_bytes, + + SeqOut: parsed.SeqInMax, + SeqInMax: parsed.SeqOut, + } + + return out, pending, nil + } + + pending = append(pending, line) + } +} + diff --git a/handshake_test.go b/handshake_test.go new file mode 100644 index 0000000..a0da947 --- /dev/null +++ b/handshake_test.go @@ -0,0 +1,132 @@ +package ownwire_sdk_test + +import ( + "context" + "crypto/ecdh" + "crypto/rand" + "encoding/base64" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + sdk "ownwire.net/ownwire-sdk" +) + +var _ = Describe("Handshaker", func() { + It("sends /create when no resume session_id is provided and derives shared key", func() { + ctx := context.Background() + conn := sdk_test_new_inmem_conn() + + client_kp, err := sdk.GenClientKey() + Expect(err).To(BeNil()) + + h := sdk.Handshaker{ + GenClientKeyF: func() (sdk.Keypair, error) { + return client_kp, nil + }, + } + + curve := ecdh.P256() + server_priv, err := curve.GenerateKey(rand.Reader) + Expect(err).To(BeNil()) + server_pub_raw := server_priv.PublicKey().Bytes() + server_pub_b64 := base64.StdEncoding.EncodeToString(server_pub_raw) + + salt_raw := make([]byte, 32) + _, err = rand.Read(salt_raw) + Expect(err).To(BeNil()) + salt_b64 := base64.StdEncoding.EncodeToString(salt_raw) + + session_id := "cb653f53-6f7d-4aeb-ba0d-d2b17c290d8a" + + go func() { + written := <-conn.write_ch + Expect(written).To(HavePrefix("/create:")) + + conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":12:34" + }() + + state, pending, err := h.Run(ctx, conn, "") + Expect(err).To(BeNil()) + Expect(pending).To(BeEmpty()) + + Expect(state.SessionId).To(Equal(session_id)) + Expect(state.ClientPubKeyB64).ToNot(BeEmpty()) + + want_key, err := sdk.DeriveSharedKey(session_id, client_kp.ClientPriv, server_pub_raw, salt_raw) + Expect(err).To(BeNil()) + Expect(state.SharedKey).To(Equal(want_key)) + + Expect(state.SeqOut).To(Equal(uint64(34))) + Expect(state.SeqInMax).To(Equal(uint64(12))) + }) + + It("sends /open when resume session_id is provided", func() { + ctx := context.Background() + conn := sdk_test_new_inmem_conn() + + client_kp, err := sdk.GenClientKey() + Expect(err).To(BeNil()) + + h := sdk.Handshaker{ + GenClientKeyF: func() (sdk.Keypair, error) { + return client_kp, nil + }, + } + + curve := ecdh.P256() + server_priv, err := curve.GenerateKey(rand.Reader) + Expect(err).To(BeNil()) + server_pub_raw := server_priv.PublicKey().Bytes() + server_pub_b64 := base64.StdEncoding.EncodeToString(server_pub_raw) + + salt_raw := make([]byte, 32) + _, err = rand.Read(salt_raw) + Expect(err).To(BeNil()) + salt_b64 := base64.StdEncoding.EncodeToString(salt_raw) + + session_id := "cb653f53-6f7d-4aeb-ba0d-d2b17c290d8a" + + go func() { + written := <-conn.write_ch + Expect(written).To(HavePrefix("/open:" + session_id + ":")) + + conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":1:2" + }() + + _, pending, err := h.Run(ctx, conn, session_id) + Expect(err).To(BeNil()) + Expect(pending).To(BeEmpty()) + }) +}) + +func sdk_test_new_inmem_conn() *sdk_test_inmem_conn { + return &sdk_test_inmem_conn{ + write_ch: make(chan string, 16), + read_ch: make(chan string, 16), + } +} + +type sdk_test_inmem_conn struct { + write_ch chan string + read_ch chan string +} + +func (c *sdk_test_inmem_conn) WriteText(ctx context.Context, s string) error { + select { + case c.write_ch <- s: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *sdk_test_inmem_conn) ReadText(ctx context.Context) (string, error) { + select { + case s := <-c.read_ch: + return s, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} +