Implemented conn.go and hanshake.go + unit tests
This commit is contained in:
10
conn.go
Normal file
10
conn.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package ownwire_sdk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn interface {
|
||||||
|
WriteText(ctx context.Context, s string) error
|
||||||
|
ReadText(ctx context.Context) (string, error)
|
||||||
|
}
|
||||||
44
conn_inmem_test.go
Normal file
44
conn_inmem_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// This file defines an in-memory implementation of the Conn interface that is
|
||||||
|
// used ONLY by unit tests.
|
||||||
|
//
|
||||||
|
// The inmem_conn replaces a real WebSocket or network connection with simple
|
||||||
|
// Go channels. This allows handshake, sequencing, and message logic to be
|
||||||
|
// tested deterministically without spinning up a server, opening sockets,
|
||||||
|
// or relying on timing.
|
||||||
|
|
||||||
|
package ownwire_sdk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inmem_conn struct {
|
||||||
|
write_ch chan string
|
||||||
|
read_ch chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func new_inmem_conn() *inmem_conn {
|
||||||
|
return &inmem_conn{
|
||||||
|
write_ch: make(chan string, 16),
|
||||||
|
read_ch: make(chan string, 16),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *inmem_conn) WriteText(ctx context.Context, s string) error {
|
||||||
|
select {
|
||||||
|
case c.write_ch <- s:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *inmem_conn) ReadText(ctx context.Context) (string, error) {
|
||||||
|
select {
|
||||||
|
case s := <-c.read_ch:
|
||||||
|
return s, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
118
handshake.go
Normal file
118
handshake.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package ownwire_sdk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionState struct {
|
||||||
|
SessionId string
|
||||||
|
ClientPubKeyB64 string
|
||||||
|
|
||||||
|
ServerPubKeyB64 string
|
||||||
|
SaltB64 string
|
||||||
|
|
||||||
|
SharedKey [32]byte
|
||||||
|
SessionIdBytes [16]byte
|
||||||
|
|
||||||
|
SeqOut uint64
|
||||||
|
SeqInMax uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handshaker struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
GenClientKeyF func() (Keypair, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Handshaker) EnsureDefaults() Handshaker {
|
||||||
|
if h.Timeout == 0 {
|
||||||
|
h.Timeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if h.GenClientKeyF == nil {
|
||||||
|
h.GenClientKeyF = GenClientKey
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Handshaker) Run(ctx context.Context, conn Conn, resume_session_id string) (SessionState, []string, error) {
|
||||||
|
h = h.EnsureDefaults()
|
||||||
|
|
||||||
|
timeout_ctx, cancel := context.WithTimeout(ctx, h.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
keypair, err := h.GenClientKeyF()
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client_pub_b64 := base64.StdEncoding.EncodeToString(keypair.ClientPub)
|
||||||
|
|
||||||
|
if resume_session_id != "" {
|
||||||
|
open_cmd := fmt.Sprintf("/open:%s:%s", resume_session_id, client_pub_b64)
|
||||||
|
if err := conn.WriteText(timeout_ctx, open_cmd); err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
create_cmd := fmt.Sprintf("/create:%s", client_pub_b64)
|
||||||
|
if err := conn.WriteText(timeout_ctx, create_cmd); err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pending := make([]string, 0, 8)
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := conn.ReadText(timeout_ctx)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(line) >= 9 && line[:9] == "/session:" {
|
||||||
|
parsed, err := ParseSessionInit(line)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
server_pub_raw, err := base64.StdEncoding.DecodeString(parsed.ServerPubKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, fmt.Errorf("invalid server pubkey base64: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
salt_raw, err := base64.StdEncoding.DecodeString(parsed.SaltB64)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, fmt.Errorf("invalid salt base64: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_key, err := DeriveSharedKey(parsed.SessionId, keypair.ClientPriv, server_pub_raw, salt_raw)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session_id_bytes, err := ParseUUIDBytes(parsed.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
return SessionState{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := SessionState{
|
||||||
|
SessionId: parsed.SessionId,
|
||||||
|
ClientPubKeyB64: client_pub_b64,
|
||||||
|
|
||||||
|
ServerPubKeyB64: parsed.ServerPubKeyB64,
|
||||||
|
SaltB64: parsed.SaltB64,
|
||||||
|
|
||||||
|
SharedKey: shared_key,
|
||||||
|
SessionIdBytes: session_id_bytes,
|
||||||
|
|
||||||
|
SeqOut: parsed.SeqInMax,
|
||||||
|
SeqInMax: parsed.SeqOut,
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, pending, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pending = append(pending, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
132
handshake_test.go
Normal file
132
handshake_test.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package ownwire_sdk_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
sdk "ownwire.net/ownwire-sdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Handshaker", func() {
|
||||||
|
It("sends /create when no resume session_id is provided and derives shared key", func() {
|
||||||
|
ctx := context.Background()
|
||||||
|
conn := sdk_test_new_inmem_conn()
|
||||||
|
|
||||||
|
client_kp, err := sdk.GenClientKey()
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
|
||||||
|
h := sdk.Handshaker{
|
||||||
|
GenClientKeyF: func() (sdk.Keypair, error) {
|
||||||
|
return client_kp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
curve := ecdh.P256()
|
||||||
|
server_priv, err := curve.GenerateKey(rand.Reader)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
server_pub_raw := server_priv.PublicKey().Bytes()
|
||||||
|
server_pub_b64 := base64.StdEncoding.EncodeToString(server_pub_raw)
|
||||||
|
|
||||||
|
salt_raw := make([]byte, 32)
|
||||||
|
_, err = rand.Read(salt_raw)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
salt_b64 := base64.StdEncoding.EncodeToString(salt_raw)
|
||||||
|
|
||||||
|
session_id := "cb653f53-6f7d-4aeb-ba0d-d2b17c290d8a"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
written := <-conn.write_ch
|
||||||
|
Expect(written).To(HavePrefix("/create:"))
|
||||||
|
|
||||||
|
conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":12:34"
|
||||||
|
}()
|
||||||
|
|
||||||
|
state, pending, err := h.Run(ctx, conn, "")
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(pending).To(BeEmpty())
|
||||||
|
|
||||||
|
Expect(state.SessionId).To(Equal(session_id))
|
||||||
|
Expect(state.ClientPubKeyB64).ToNot(BeEmpty())
|
||||||
|
|
||||||
|
want_key, err := sdk.DeriveSharedKey(session_id, client_kp.ClientPriv, server_pub_raw, salt_raw)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(state.SharedKey).To(Equal(want_key))
|
||||||
|
|
||||||
|
Expect(state.SeqOut).To(Equal(uint64(34)))
|
||||||
|
Expect(state.SeqInMax).To(Equal(uint64(12)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends /open when resume session_id is provided", func() {
|
||||||
|
ctx := context.Background()
|
||||||
|
conn := sdk_test_new_inmem_conn()
|
||||||
|
|
||||||
|
client_kp, err := sdk.GenClientKey()
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
|
||||||
|
h := sdk.Handshaker{
|
||||||
|
GenClientKeyF: func() (sdk.Keypair, error) {
|
||||||
|
return client_kp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
curve := ecdh.P256()
|
||||||
|
server_priv, err := curve.GenerateKey(rand.Reader)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
server_pub_raw := server_priv.PublicKey().Bytes()
|
||||||
|
server_pub_b64 := base64.StdEncoding.EncodeToString(server_pub_raw)
|
||||||
|
|
||||||
|
salt_raw := make([]byte, 32)
|
||||||
|
_, err = rand.Read(salt_raw)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
salt_b64 := base64.StdEncoding.EncodeToString(salt_raw)
|
||||||
|
|
||||||
|
session_id := "cb653f53-6f7d-4aeb-ba0d-d2b17c290d8a"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
written := <-conn.write_ch
|
||||||
|
Expect(written).To(HavePrefix("/open:" + session_id + ":"))
|
||||||
|
|
||||||
|
conn.read_ch <- "/session:" + session_id + ":" + server_pub_b64 + ":" + salt_b64 + ":1:2"
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, pending, err := h.Run(ctx, conn, session_id)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(pending).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
func sdk_test_new_inmem_conn() *sdk_test_inmem_conn {
|
||||||
|
return &sdk_test_inmem_conn{
|
||||||
|
write_ch: make(chan string, 16),
|
||||||
|
read_ch: make(chan string, 16),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sdk_test_inmem_conn struct {
|
||||||
|
write_ch chan string
|
||||||
|
read_ch chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sdk_test_inmem_conn) WriteText(ctx context.Context, s string) error {
|
||||||
|
select {
|
||||||
|
case c.write_ch <- s:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sdk_test_inmem_conn) ReadText(ctx context.Context) (string, error) {
|
||||||
|
select {
|
||||||
|
case s := <-c.read_ch:
|
||||||
|
return s, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user