119 lines
2.6 KiB
Go
119 lines
2.6 KiB
Go
|
|
package ownwire_sdk
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/base64"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// SessionState is the outcome of a completed handshake: the session
// identifier, the public material exchanged with the server, the derived
// shared key, and the sequence counters the session resumes from.
type SessionState struct {
	// SessionId is the textual session identifier reported by the server
	// (Run parses it with ParseUUIDBytes, so presumably a UUID string).
	SessionId string

	// Standard-base64 encodings of the client's public key, the server's
	// public key and the key-derivation salt, in that order.
	ClientPubKeyB64, ServerPubKeyB64, SaltB64 string

	// SharedKey is the 32-byte key produced by DeriveSharedKey.
	SharedKey [32]byte
	// SessionIdBytes is SessionId decoded into 16 raw bytes.
	SessionIdBytes [16]byte

	// SeqOut and SeqInMax are the outbound and inbound sequence counters.
	// NOTE(review): exact semantics inferred from the names — confirm with
	// the code that consumes them.
	SeqOut, SeqInMax uint64
}
|
||
|
|
|
||
|
|
type Handshaker struct {
|
||
|
|
Timeout time.Duration
|
||
|
|
GenClientKeyF func() (Keypair, error)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (h Handshaker) EnsureDefaults() Handshaker {
|
||
|
|
if h.Timeout == 0 {
|
||
|
|
h.Timeout = 10 * time.Second
|
||
|
|
}
|
||
|
|
if h.GenClientKeyF == nil {
|
||
|
|
h.GenClientKeyF = GenClientKey
|
||
|
|
}
|
||
|
|
return h
|
||
|
|
}
|
||
|
|
|
||
|
|
func (h Handshaker) Run(ctx context.Context, conn Conn, resume_session_id string) (SessionState, []string, error) {
|
||
|
|
h = h.EnsureDefaults()
|
||
|
|
|
||
|
|
timeout_ctx, cancel := context.WithTimeout(ctx, h.Timeout)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
keypair, err := h.GenClientKeyF()
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
client_pub_b64 := base64.StdEncoding.EncodeToString(keypair.ClientPub)
|
||
|
|
|
||
|
|
if resume_session_id != "" {
|
||
|
|
open_cmd := fmt.Sprintf("/open:%s:%s", resume_session_id, client_pub_b64)
|
||
|
|
if err := conn.WriteText(timeout_ctx, open_cmd); err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
create_cmd := fmt.Sprintf("/create:%s", client_pub_b64)
|
||
|
|
if err := conn.WriteText(timeout_ctx, create_cmd); err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pending := make([]string, 0, 8)
|
||
|
|
|
||
|
|
for {
|
||
|
|
line, err := conn.ReadText(timeout_ctx)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(line) >= 9 && line[:9] == "/session:" {
|
||
|
|
parsed, err := ParseSessionInit(line)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
server_pub_raw, err := base64.StdEncoding.DecodeString(parsed.ServerPubKeyB64)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, fmt.Errorf("invalid server pubkey base64: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
salt_raw, err := base64.StdEncoding.DecodeString(parsed.SaltB64)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, fmt.Errorf("invalid salt base64: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
shared_key, err := DeriveSharedKey(parsed.SessionId, keypair.ClientPriv, server_pub_raw, salt_raw)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
session_id_bytes, err := ParseUUIDBytes(parsed.SessionId)
|
||
|
|
if err != nil {
|
||
|
|
return SessionState{}, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
out := SessionState{
|
||
|
|
SessionId: parsed.SessionId,
|
||
|
|
ClientPubKeyB64: client_pub_b64,
|
||
|
|
|
||
|
|
ServerPubKeyB64: parsed.ServerPubKeyB64,
|
||
|
|
SaltB64: parsed.SaltB64,
|
||
|
|
|
||
|
|
SharedKey: shared_key,
|
||
|
|
SessionIdBytes: session_id_bytes,
|
||
|
|
|
||
|
|
SeqOut: parsed.SeqInMax,
|
||
|
|
SeqInMax: parsed.SeqOut,
|
||
|
|
}
|
||
|
|
|
||
|
|
return out, pending, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
pending = append(pending, line)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|