From dc92793db9a905c9d034087e025c2e4a2ed59640 Mon Sep 17 00:00:00 2001 From: robert Date: Thu, 26 Feb 2026 17:09:01 +0000 Subject: [PATCH] Fix: don't accept unencrypted messages --- client.go | 92 ++++++++++++++++++++++++++++---------------------- client_test.go | 6 ++-- 2 files changed, 56 insertions(+), 42 deletions(-) diff --git a/client.go b/client.go index c32b03c..bfe83e4 100644 --- a/client.go +++ b/client.go @@ -185,53 +185,65 @@ func (c *Client) read_loop(ctx context.Context, pending []string) { } func (c *Client) handle_incoming_text(s string) { - if len(s) > 0 && s[0] == '/' { - // Ignore unknown commands after handshake for now. - return - } + if len(s) > 0 && s[0] == '/' { + return + } - var in incoming_frame - if err := json.Unmarshal([]byte(s), &in); err != nil { - return - } + var in incoming_frame + if err := json.Unmarshal([]byte(s), &in); err != nil { + return + } - c.mu.Lock() - if !c.ready { - c.mu.Unlock() - return - } + c.mu.Lock() + if !c.ready { + c.mu.Unlock() + return + } - shared_key := c.state.SharedKey - session_id_bytes := c.state.SessionIdBytes - c.mu.Unlock() + shared_key := c.state.SharedKey + session_id_bytes := c.state.SessionIdBytes + c.mu.Unlock() - content := in.Content + // Enforce encryption after handshake. + if !in.IsEncrypted { + c.emit(Event{ + Kind: EventError, + Err: fmt.Errorf("received unencrypted message after handshake"), + }) + return + } - if in.IsEncrypted { - plain, err := DecryptAESGCM(shared_key, session_id_bytes, in.Content, in.Salt, in.SeqNum, in.IsResponse) - if err != nil { - c.emit(Event{Kind: EventError, Err: err}) - return - } - content = string(plain) - } + plain, err := DecryptAESGCM( + shared_key, + session_id_bytes, + in.Content, + in.Salt, + in.SeqNum, + in.IsResponse, + ) + if err != nil { + c.emit(Event{Kind: EventError, Err: err}) + return + } - c.mu.Lock() - if in.SeqNum > c.state.SeqInMax { - c.state.SeqInMax = in.SeqNum - } - c.mu.Unlock() + content := string(plain) - c.emit(Event{ - Kind: EventMessage, - Message: Message{ - Content: content, - Metadata: in.Metadata, - SeqNum: in.SeqNum, - IsResponse: in.IsResponse, - CreatedAt: in.CreatedAt, - }, - }) + c.mu.Lock() + if in.SeqNum > c.state.SeqInMax { + c.state.SeqInMax = in.SeqNum + } + c.mu.Unlock() + + c.emit(Event{ + Kind: EventMessage, + Message: Message{ + Content: content, + Metadata: in.Metadata, + SeqNum: in.SeqNum, + IsResponse: in.IsResponse, + CreatedAt: in.CreatedAt, + }, + }) } func (c *Client) emit(ev Event) { diff --git a/client_test.go b/client_test.go index 083811b..eae3395 100644 --- a/client_test.go +++ b/client_test.go @@ -55,6 +55,10 @@ var _ = Describe("Client", func() { written := <-conn.write_ch Expect(written).To(HavePrefix("/create:")) + // First complete handshake + conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":12:34" + + // Then send encrypted history (now processed normally) history_enc, err := sdk.EncryptAESGCM(shared_key, session_id_bytes, []byte("hist"), 10, false) Expect(err).To(BeNil()) @@ -68,8 +72,6 @@ var _ = Describe("Client", func() { "created_at": int64(1), }) conn.read_ch <- string(history_json) - - conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":12:34" }() err = client.Connect(ctx, "")