From ac2d1887a936b6143a42f1cd68ef316db5e7a759 Mon Sep 17 00:00:00 2001 From: robert Date: Sun, 4 Jan 2026 20:39:38 +0000 Subject: [PATCH] Add client core, handshake pending frames, and in-memory Conn tests --- client.go | 226 ++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 117 ++++++++++++++++++++++++ conn_test.go | 37 ++++++++ handshake_test.go | 30 ------ message.go | 25 +++++ 5 files changed, 405 insertions(+), 30 deletions(-) create mode 100644 client.go create mode 100644 client_test.go create mode 100644 conn_test.go create mode 100644 message.go diff --git a/client.go b/client.go new file mode 100644 index 0000000..f69cc96 --- /dev/null +++ b/client.go @@ -0,0 +1,226 @@ +package ownwire_sdk + +import ( + "context" + "encoding/json" + "fmt" + "sync" +) + +type ClientOptions struct { + Conn Conn + Handshaker Handshaker + EventsBuffer int +} + +type Client struct { + conn Conn + handshaker Handshaker + + events_ch chan Event + + mu sync.Mutex + state SessionState + ready bool + closed bool + close_ch chan struct{} +} + +func NewClient(opts ClientOptions) *Client { + events_buffer := opts.EventsBuffer + if events_buffer == 0 { + events_buffer = 64 + } + + handshaker := opts.Handshaker.EnsureDefaults() + + return &Client{ + conn: opts.Conn, + handshaker: handshaker, + events_ch: make(chan Event, events_buffer), + close_ch: make(chan struct{}), + } +} + +func (c *Client) Events() <-chan Event { + return c.events_ch +} + +func (c *Client) Connect(ctx context.Context, resume_session_id string) error { + if c.conn == nil { + return fmt.Errorf("no conn configured (ws transport not added yet)") + } + + c.mu.Lock() + if c.ready { + c.mu.Unlock() + return nil + } + c.mu.Unlock() + + state, pending, err := c.handshaker.Run(ctx, c.conn, resume_session_id) + if err != nil { + return err + } + + c.mu.Lock() + c.state = state + c.ready = true + c.mu.Unlock() + + c.emit(Event{Kind: EventOpened}) + + go c.read_loop(pending) + + return nil +} + +func (c *Client) Close() { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + c.closed = true + close(c.close_ch) + c.mu.Unlock() + + c.emit(Event{Kind: EventClosed}) +} + +func (c *Client) Send(ctx context.Context, content string, metadata string) error { + c.mu.Lock() + if !c.ready { + c.mu.Unlock() + return fmt.Errorf("client not connected") + } + + c.state.SeqOut++ + seq_num := c.state.SeqOut + + shared_key := c.state.SharedKey + session_id_bytes := c.state.SessionIdBytes + c.mu.Unlock() + + payload := outgoing_frame{ + Content: content, + Metadata: metadata, + SeqNum: seq_num, + IsEncrypted: true, + Salt: "", + } + + enc, err := EncryptAESGCM(shared_key, session_id_bytes, []byte(content), seq_num, false) + if err != nil { + return err + } + + payload.Content = enc.ContentB64 + payload.Salt = enc.SaltHex + + buf, err := json.Marshal(payload) + if err != nil { + return err + } + + return c.conn.WriteText(ctx, string(buf)) +} + +func (c *Client) read_loop(pending []string) { + for _, s := range pending { + c.handle_incoming_text(s) + } + + for { + select { + case <-c.close_ch: + return + default: + } + + s, err := c.conn.ReadText(context.Background()) + if err != nil { + c.emit(Event{Kind: EventError, Err: err}) + return + } + + c.handle_incoming_text(s) + } +} + +func (c *Client) handle_incoming_text(s string) { + if len(s) > 0 && s[0] == '/' { + // Ignore unknown commands after handshake for now. + return + } + + var in incoming_frame + if err := json.Unmarshal([]byte(s), &in); err != nil { + return + } + + c.mu.Lock() + if !c.ready { + c.mu.Unlock() + return + } + + shared_key := c.state.SharedKey + session_id_bytes := c.state.SessionIdBytes + c.mu.Unlock() + + content := in.Content + + if in.IsEncrypted { + plain, err := DecryptAESGCM(shared_key, session_id_bytes, in.Content, in.Salt, in.SeqNum, in.IsResponse) + if err != nil { + c.emit(Event{Kind: EventError, Err: err}) + return + } + content = string(plain) + } + + c.mu.Lock() + if in.SeqNum > c.state.SeqInMax { + c.state.SeqInMax = in.SeqNum + } + c.mu.Unlock() + + c.emit(Event{ + Kind: EventMessage, + Message: Message{ + Content: content, + Metadata: in.Metadata, + SeqNum: in.SeqNum, + IsResponse: in.IsResponse, + CreatedAt: in.CreatedAt, + }, + }) +} + +func (c *Client) emit(ev Event) { + select { + case c.events_ch <- ev: + default: + // Drop if user isn't consuming. + } +} + +type outgoing_frame struct { + Content string `json:"content"` + Metadata string `json:"metadata"` + SeqNum uint64 `json:"seq_num"` + IsEncrypted bool `json:"is_encrypted"` + Salt string `json:"salt"` +} + +type incoming_frame struct { + Content string `json:"content"` + Metadata string `json:"metadata"` + SeqNum uint64 `json:"seq_num"` + IsEncrypted bool `json:"is_encrypted"` + IsResponse bool `json:"is_response"` + Salt string `json:"salt"` + CreatedAt int64 `json:"created_at"` +} + diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..083811b --- /dev/null +++ b/client_test.go @@ -0,0 +1,117 @@ +package ownwire_sdk_test + +import ( + "context" + "crypto/ecdh" + "crypto/rand" + "encoding/base64" + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + sdk "ownwire.net/ownwire-sdk" +) + +var _ = Describe("Client", func() { + It("connects, emits opened, handles pending frame, and Send writes encrypted JSON", func() { + ctx := context.Background() + conn := sdk_test_new_inmem_conn() + + client_kp, err := sdk.GenClientKey() + Expect(err).To(BeNil()) + + curve := ecdh.P256() + server_priv, err := curve.GenerateKey(rand.Reader) + Expect(err).To(BeNil()) + server_pub_raw := server_priv.PublicKey().Bytes() + server_pub_b64 := base64.StdEncoding.EncodeToString(server_pub_raw) + + salt_raw := make([]byte, 32) + _, err = rand.Read(salt_raw) + Expect(err).To(BeNil()) + salt_b64 := base64.StdEncoding.EncodeToString(salt_raw) + + session_id := "cb653f53-6f7d-4aeb-ba0d-d2b17c290d8a" + + shared_key, err := sdk.DeriveSharedKey(session_id, client_kp.ClientPriv, server_pub_raw, salt_raw) + Expect(err).To(BeNil()) + + session_id_bytes, err := sdk.ParseUUIDBytes(session_id) + Expect(err).To(BeNil()) + + opts := sdk.ClientOptions{ + Conn: conn, + Handshaker: sdk.Handshaker{ + GenClientKeyF: func() (sdk.Keypair, error) { + return client_kp, nil + }, + }, + } + client := sdk.NewClient(opts) + + go func() { + written := <-conn.write_ch + Expect(written).To(HavePrefix("/create:")) + + history_enc, err := sdk.EncryptAESGCM(shared_key, session_id_bytes, []byte("hist"), 10, false) + Expect(err).To(BeNil()) + + history_json, _ := json.Marshal(map[string]any{ + "content": history_enc.ContentB64, + "metadata": "", + "seq_num": 10, + "is_encrypted": true, + "is_response": false, + "salt": history_enc.SaltHex, + "created_at": int64(1), + }) + conn.read_ch <- string(history_json) + + conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":12:34" + }() + + err = client.Connect(ctx, "") + Expect(err).To(BeNil()) + + select { + case ev := <-client.Events(): + Expect(ev.Kind).To(Equal(sdk.EventOpened)) + case <-time.After(2 * time.Second): + Fail("timeout waiting for opened event") + } + + select { + case ev := <-client.Events(): + Expect(ev.Kind).To(Equal(sdk.EventMessage)) + Expect(ev.Message.Content).To(Equal("hist")) + Expect(ev.Message.SeqNum).To(Equal(uint64(10))) + case <-time.After(2 * time.Second): + Fail("timeout waiting for history message event") + } + + go func() { + written := <-conn.write_ch + + var payload map[string]any + err := json.Unmarshal([]byte(written), &payload) + Expect(err).To(BeNil()) + + Expect(payload["is_encrypted"]).To(Equal(true)) + Expect(payload["salt"]).ToNot(BeEmpty()) + + seq_num := uint64(payload["seq_num"].(float64)) + content_b64 := payload["content"].(string) + salt_hex := payload["salt"].(string) + + pt, err := sdk.DecryptAESGCM(shared_key, session_id_bytes, content_b64, salt_hex, seq_num, false) + Expect(err).To(BeNil()) + Expect(string(pt)).To(Equal("hello")) + }() + + err = client.Send(ctx, "hello", "") + Expect(err).To(BeNil()) + }) +}) + diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..0f21a1f --- /dev/null +++ b/conn_test.go @@ -0,0 +1,37 @@ +package ownwire_sdk_test + +import "context" + +// sdk_test_inmem_conn is a test-only in-memory Conn implementation. +// It lets tests drive protocol logic deterministically without a real websocket +// or a server. +type sdk_test_inmem_conn struct { + write_ch chan string + read_ch chan string +} + +func sdk_test_new_inmem_conn() *sdk_test_inmem_conn { + return &sdk_test_inmem_conn{ + write_ch: make(chan string, 16), + read_ch: make(chan string, 16), + } +} + +func (c *sdk_test_inmem_conn) WriteText(ctx context.Context, s string) error { + select { + case c.write_ch <- s: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *sdk_test_inmem_conn) ReadText(ctx context.Context) (string, error) { + select { + case s := <-c.read_ch: + return s, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + diff --git a/handshake_test.go b/handshake_test.go index a0da947..7e0a9d2 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -100,33 +100,3 @@ var _ = Describe("Handshaker", func() { }) }) -func sdk_test_new_inmem_conn() *sdk_test_inmem_conn { - return &sdk_test_inmem_conn{ - write_ch: make(chan string, 16), - read_ch: make(chan string, 16), - } -} - -type sdk_test_inmem_conn struct { - write_ch chan string - read_ch chan string -} - -func (c *sdk_test_inmem_conn) WriteText(ctx context.Context, s string) error { - select { - case c.write_ch <- s: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -func (c *sdk_test_inmem_conn) ReadText(ctx context.Context) (string, error) { - select { - case s := <-c.read_ch: - return s, nil - case <-ctx.Done(): - return "", ctx.Err() - } -} - diff --git a/message.go b/message.go new file mode 100644 index 0000000..c162873 --- /dev/null +++ b/message.go @@ -0,0 +1,25 @@ +package ownwire_sdk + +type EventKind uint8 + +const ( + EventOpened EventKind = iota + 1 + EventMessage + EventError + EventClosed +) + +type Event struct { + Kind EventKind + Message Message + Err error +} + +type Message struct { + Content string + Metadata string + SeqNum uint64 + IsResponse bool + CreatedAt int64 +} +