diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2b8c748 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +ownwire-cli diff --git a/client.go b/client.go index f69cc96..c32b03c 100644 --- a/client.go +++ b/client.go @@ -8,22 +8,29 @@ import ( ) type ClientOptions struct { + Url string Conn Conn Handshaker Handshaker EventsBuffer int } type Client struct { - conn Conn + url string + conn Conn handshaker Handshaker events_ch chan Event - mu sync.Mutex - state SessionState - ready bool - closed bool - close_ch chan struct{} + mu sync.Mutex + state SessionState + ready bool + closed bool + close_ch chan struct{} + read_cancel context.CancelFunc +} + +type conn_closer interface { + Close() error } func NewClient(opts ClientOptions) *Client { @@ -35,6 +42,7 @@ func NewClient(opts ClientOptions) *Client { handshaker := opts.Handshaker.EnsureDefaults() return &Client{ + url: opts.Url, conn: opts.Conn, handshaker: handshaker, events_ch: make(chan Event, events_buffer), @@ -47,10 +55,6 @@ func (c *Client) Events() <-chan Event { } func (c *Client) Connect(ctx context.Context, resume_session_id string) error { - if c.conn == nil { - return fmt.Errorf("no conn configured (ws transport not added yet)") - } - c.mu.Lock() if c.ready { c.mu.Unlock() @@ -58,19 +62,39 @@ func (c *Client) Connect(ctx context.Context, resume_session_id string) error { } c.mu.Unlock() - state, pending, err := c.handshaker.Run(ctx, c.conn, resume_session_id) + conn := c.conn + if conn == nil { + if c.url == "" { + return fmt.Errorf("no Url configured and no Conn provided") + } + + ws_conn, err := DialWs(ctx, c.url) + if err != nil { + return err + } + conn = ws_conn + + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + } + + state, pending, err := c.handshaker.Run(ctx, conn, resume_session_id) if err != nil { return err } + read_ctx, read_cancel := context.WithCancel(context.Background()) + c.mu.Lock() c.state = state c.ready = true + c.read_cancel = read_cancel c.mu.Unlock() c.emit(Event{Kind: EventOpened}) - go c.read_loop(pending) + go c.read_loop(read_ctx, pending) return nil } @@ -82,9 +106,20 @@ func (c *Client) Close() { return } c.closed = true + + read_cancel := c.read_cancel + conn := c.conn close(c.close_ch) c.mu.Unlock() + if read_cancel != nil { + read_cancel() + } + + if closer, ok := conn.(conn_closer); ok { + _ = closer.Close() + } + c.emit(Event{Kind: EventClosed}) } @@ -100,6 +135,7 @@ func (c *Client) Send(ctx context.Context, content string, metadata string) erro shared_key := c.state.SharedKey session_id_bytes := c.state.SessionIdBytes + conn := c.conn c.mu.Unlock() payload := outgoing_frame{ @@ -123,10 +159,10 @@ func (c *Client) Send(ctx context.Context, content string, metadata string) erro return err } - return c.conn.WriteText(ctx, string(buf)) + return conn.WriteText(ctx, string(buf)) } -func (c *Client) read_loop(pending []string) { +func (c *Client) read_loop(ctx context.Context, pending []string) { for _, s := range pending { c.handle_incoming_text(s) } @@ -138,7 +174,7 @@ func (c *Client) read_loop(pending []string) { default: } - s, err := c.conn.ReadText(context.Background()) + s, err := c.conn.ReadText(ctx) if err != nil { c.emit(Event{Kind: EventError, Err: err}) return diff --git a/conn_ws.go b/conn_ws.go new file mode 100644 index 0000000..dbebb7e --- /dev/null +++ b/conn_ws.go @@ -0,0 +1,44 @@ +package ownwire_sdk + +import ( + "context" + "fmt" + + "nhooyr.io/websocket" +) + +type WsConn struct { + conn *websocket.Conn +} + +func DialWs(ctx context.Context, url string) (*WsConn, error) { + conn, _, err := websocket.Dial(ctx, url, nil) + if err != nil { + return nil, err + } + return &WsConn{conn: conn}, nil +} + +func (c *WsConn) WriteText(ctx context.Context, s string) error { + return c.conn.Write(ctx, websocket.MessageText, []byte(s)) +} + +func (c *WsConn) ReadText(ctx context.Context) (string, error) { + msg_type, data, err := c.conn.Read(ctx) + if err != nil { + return "", err + } + + if msg_type != websocket.MessageText { + // We only expect text frames. If something else arrives, surface it. + return "", fmt.Errorf("unexpected websocket message type: %v", msg_type) + } + + return string(data), nil +} + +func (c *WsConn) Close() error { + // Normal closure. + return c.conn.Close(websocket.StatusNormalClosure, "normal") +} + diff --git a/example/cli/main.go b/example/cli/main.go new file mode 100644 index 0000000..bfd88f1 --- /dev/null +++ b/example/cli/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + sdk "ownwire.net/ownwire-sdk" +) + +func main() { + url := flag.String("url", "", "websocket url, e.g. ws://localhost:8080/ownwire") + resume_session_id := flag.String("resume", "", "resume session id (optional)") + flag.Parse() + + if *url == "" { + fmt.Fprintln(os.Stderr, "missing -url") + os.Exit(2) + } + + client := sdk.NewClient(sdk.ClientOptions{ + Url: *url, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := client.Connect(ctx, *resume_session_id); err != nil { + fmt.Fprintf(os.Stderr, "connect failed: %v\n", err) + os.Exit(1) + } + + fmt.Println("connected") + + sig_ch := make(chan os.Signal, 1) + signal.Notify(sig_ch, syscall.SIGINT, syscall.SIGTERM) + + go func() { + for ev := range client.Events() { + switch ev.Kind { + case sdk.EventOpened: + fmt.Println("[opened]") + case sdk.EventClosed: + fmt.Println("[closed]") + return + case sdk.EventError: + fmt.Fprintf(os.Stderr, "[error] %v\n", ev.Err) + case sdk.EventMessage: + if ev.Message.Metadata != "" { + fmt.Printf("[msg #%d] %s | %s\n", ev.Message.SeqNum, ev.Message.Metadata, ev.Message.Content) + } else { + fmt.Printf("[msg #%d] %s\n", ev.Message.SeqNum, ev.Message.Content) + } + } + } + }() + + go func() { + <-sig_ch + client.Close() + os.Exit(0) + }() + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := scanner.Text() + send_ctx, send_cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := client.Send(send_ctx, line, "") + send_cancel() + if err != nil { + fmt.Fprintf(os.Stderr, "send failed: %v\n", err) + } + } + + if err := scanner.Err(); err != nil { + fmt.Fprintf(os.Stderr, "stdin error: %v\n", err) + } + + client.Close() +} diff --git a/go.mod b/go.mod index 057812a..bd7ec54 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.4 require ( github.com/onsi/ginkgo/v2 v2.27.3 github.com/onsi/gomega v1.38.3 + nhooyr.io/websocket v1.8.17 ) require ( diff --git a/go.sum b/go.sum index 7d22274..d1aeba7 100644 --- a/go.sum +++ b/go.sum @@ -67,3 +67,5 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= +nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=