mirror of
https://github.com/saymrwulf/puncture.git
synced 2026-05-22 22:01:18 +00:00
384 lines
12 KiB
Go
384 lines
12 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"puncture-go/internal/app"
|
|
)
|
|
|
|
// staticFS holds the web UI page embedded into the binary at build time;
// handleIndex serves static/index.html out of it.
//
//go:embed static/index.html
var staticFS embed.FS
|
|
|
|
type HTTPServer struct {
|
|
state *app.AppState
|
|
remoteToken string
|
|
mux *http.ServeMux
|
|
}
|
|
|
|
// New builds an HTTPServer around state. remoteToken has surrounding
// whitespace removed before it is stored, and all routes are registered on a
// fresh ServeMux before the server is returned.
func New(state *app.AppState, remoteToken string) *HTTPServer {
	srv := new(HTTPServer)
	srv.state = state
	srv.remoteToken = strings.TrimSpace(remoteToken)
	srv.mux = http.NewServeMux()
	srv.routes()
	return srv
}
|
|
|
|
// Handler returns the root handler of this server: the ServeMux holding every
// route registered by New. Callers may wrap it with middleware before mounting
// it on an http.Server.
func (s *HTTPServer) Handler() http.Handler {
	return http.Handler(s.mux)
}
|
|
|
|
// routes binds every URL pattern to its handler on s.mux. ServeMux panics on a
// duplicate pattern, so this must run exactly once per server (New does that).
func (s *HTTPServer) routes() {
	healthz := func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) }

	table := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		// UI page and liveness probe.
		{"/", s.handleIndex},
		{"/healthz", healthz},

		// Whole-state views and lifecycle.
		{"/api/state", s.handleState},
		{"/api/live/state", s.handleState},
		{"/api/export", s.handleExport},
		{"/api/reset", s.handleReset},

		// Derive / puncture operations.
		{"/api/derive", s.handleDerive},
		{"/api/puncture", s.handlePuncture},
		{"/api/remote/puncture-provider", s.handleRemotePunctureProvider},

		// Provider management.
		{"/api/providers/add", s.handleProviderAdd},
		{"/api/providers/edit", s.handleProviderEdit},
		{"/api/providers/delete", s.handleProviderDelete},

		// Asset storage and encryption.
		{"/api/assets/upload", s.handleAssetUpload},
		{"/api/assets/encrypt", s.handleAssetEncrypt},
		{"/api/assets/decrypt", s.handleAssetDecrypt},
	}

	for _, route := range table {
		s.mux.HandleFunc(route.pattern, route.handle)
	}
}
|
|
|
|
// writeJSON sends payload as an indented JSON document with the given status
// code and a Content-Type of application/json.
//
// The payload is marshalled before any header is written. If it cannot be
// encoded (e.g. it contains a func, channel or NaN), the status can still be
// changed, and the client gets a 500 with an {"ok": false, "error": ...} body
// instead of the requested status with an empty or truncated body. The output
// ends with a newline, matching json.Encoder.
func writeJSON(w http.ResponseWriter, code int, payload any) {
	body, err := json.MarshalIndent(payload, "", " ")
	if err != nil {
		code = http.StatusInternalServerError
		// A map of a bool and a string always marshals, so this error is safe to drop.
		body, _ = json.MarshalIndent(map[string]any{"ok": false, "error": "encode response: " + err.Error()}, "", " ")
	}
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A write error means the client went away; there is nobody left to report it to.
	_, _ = w.Write(body)
}
|
|
|
|
// decodeJSON reads a single JSON value from the request body into dst.
// Objects carrying fields that dst does not declare are rejected with an
// error rather than silently ignored.
func decodeJSON(r *http.Request, dst any) error {
	decoder := json.NewDecoder(r.Body)
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(dst); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// handleIndex serves the embedded UI page (static/index.html) for GET /.
//
// It is registered on the "/" pattern, which ServeMux uses as the fallback for
// every path no other route matches. Only the exact path "/" gets the page.
// Any other unmatched path gets a 404 JSON error, so clients calling a
// mistyped endpoint see an error instead of a 200 HTML page. Non-GET requests
// to "/" get 405, and a failure to read the embedded file gets 500.
func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		writeJSON(w, http.StatusNotFound, map[string]any{"ok": false, "error": "not found"})
		return
	}
	if r.Method != http.MethodGet {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}
	page, err := staticFS.ReadFile("static/index.html")
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	// A write error means the client disconnected; nothing useful can be done.
	_, _ = w.Write(page)
}
|
|
|
|
// handleState answers GET requests with {"ok": true, "state": <snapshot>},
// where the snapshot comes from AppState.Snapshot. Every other method gets a
// 405 JSON error.
func (s *HTTPServer) handleState(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
		return
	}
	writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
}
|
|
|
|
// handleExport serves GET /api/export. On success the body is the byte slice
// returned by AppState.ExportStateJSON, written verbatim with status 200 and a
// JSON content type. It is not wrapped in an {"ok": ...} envelope. Export
// failures produce a 500 JSON error, and other methods get 405.
func (s *HTTPServer) handleExport(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	exported, exportErr := s.state.ExportStateJSON()
	if exportErr != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]any{"ok": false, "error": exportErr.Error()})
		return
	}

	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// A write error here means the client went away; there is nobody to tell.
	_, _ = w.Write(exported)
}
|
|
|
|
// handleReset serves POST /api/reset. It calls AppState.Reset and replies
// with the resulting snapshot, or with a 500 JSON error if the reset failed.
// Other methods get 405.
func (s *HTTPServer) handleReset(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	resetErr := s.state.Reset()
	if resetErr != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]any{"ok": false, "error": resetErr.Error()})
		return
	}

	snapshot := s.state.Snapshot()
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": snapshot})
}
|
|
|
|
// handleDerive serves POST /api/derive with a JSON body of
// {"provider_id", "file_time_id", "purpose"} and forwards it to
// AppState.Derive. A malformed body and a failed derive are both reported as
// 400 with the error text. Success returns the updated snapshot, and other
// methods get 405.
func (s *HTTPServer) handleDerive(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		ProviderID int    `json:"provider_id"`
		FileTimeID int    `json:"file_time_id"`
		Purpose    string `json:"purpose"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.Derive(body.ProviderID, body.FileTimeID, body.Purpose)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
}
|
|
|
|
// handlePuncture serves POST /api/puncture with a JSON body of
// {"provider_id", "file_time_id"} and forwards it to AppState.Puncture.
// Decode and puncture failures are both 400 with the error text. Success
// returns the updated snapshot, and other methods get 405.
func (s *HTTPServer) handlePuncture(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		ProviderID int `json:"provider_id"`
		FileTimeID int `json:"file_time_id"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.Puncture(body.ProviderID, body.FileTimeID)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
}
|
|
|
|
// handleRemotePunctureProvider serves POST /api/remote/puncture-provider.
//
// The X-Puncture-Token header and the configured token are passed to
// AppState.RemoteTokenValid. If that rejects them, the reply is 403 with
// error "unauthorized". Otherwise the {"provider_id"} body is forwarded to
// AppState.PunctureProvider. Decode and puncture failures are 400. Success
// echoes provider_id alongside the updated snapshot, and other methods get 405.
func (s *HTTPServer) handleRemotePunctureProvider(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}
	if token := r.Header.Get("X-Puncture-Token"); !s.state.RemoteTokenValid(token, s.remoteToken) {
		writeJSON(w, http.StatusForbidden, map[string]any{"ok": false, "error": "unauthorized"})
		return
	}

	var body struct {
		ProviderID int `json:"provider_id"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.PunctureProvider(body.ProviderID)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "provider_id": body.ProviderID, "state": s.state.Snapshot()})
}
|
|
|
|
// handleProviderAdd serves POST /api/providers/add with a JSON body of
// {"provider_id", "name", "description"} and forwards it to
// AppState.AddProvider. Decode and add failures are 400 with the error text.
// Success returns the updated snapshot, and other methods get 405.
func (s *HTTPServer) handleProviderAdd(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		ProviderID  int    `json:"provider_id"`
		Name        string `json:"name"`
		Description string `json:"description"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.AddProvider(body.ProviderID, body.Name, body.Description)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
}
|
|
|
|
// handleProviderEdit serves POST /api/providers/edit with a JSON body of
// {"provider_id", "name", "description"} and forwards it to
// AppState.EditProvider. Decode and edit failures are 400 with the error text.
// Success returns the updated snapshot, and other methods get 405.
func (s *HTTPServer) handleProviderEdit(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		ProviderID  int    `json:"provider_id"`
		Name        string `json:"name"`
		Description string `json:"description"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.EditProvider(body.ProviderID, body.Name, body.Description)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
}
|
|
|
|
// handleProviderDelete serves POST /api/providers/delete with a JSON body of
// {"provider_id"} and forwards it to AppState.DeleteProvider. Decode and
// delete failures are 400 with the error text. Success returns the updated
// snapshot, and other methods get 405.
func (s *HTTPServer) handleProviderDelete(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		ProviderID int `json:"provider_id"`
	}
	err := decodeJSON(r, &body)
	if err == nil {
		err = s.state.DeleteProvider(body.ProviderID)
	}
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "state": s.state.Snapshot()})
}
|
|
|
|
func filesFromRequest(r *http.Request) ([]*multipart.FileHeader, string, error) {
|
|
if err := r.ParseMultipartForm(64 << 20); err != nil {
|
|
return nil, "", err
|
|
}
|
|
form := r.MultipartForm
|
|
if form == nil {
|
|
return nil, "", fmt.Errorf("multipart form missing")
|
|
}
|
|
files := form.File["files"]
|
|
if len(files) == 0 {
|
|
files = form.File["file"]
|
|
}
|
|
target := r.FormValue("target_subdir")
|
|
return files, target, nil
|
|
}
|
|
|
|
// handleAssetUpload serves POST /api/assets/upload. It extracts the uploaded
// files and target subdirectory with filesFromRequest, hands them to
// AppState.SaveUploads, and returns the saved entries under "uploaded".
// Parse failures and save failures are both 400. Save failures also include
// the current snapshot. Other methods get 405.
func (s *HTTPServer) handleAssetUpload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	headers, subdir, err := filesFromRequest(r)
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}

	uploaded, err := s.state.SaveUploads(headers, subdir)
	if err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error(), "state": s.state.Snapshot()})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"ok": true, "uploaded": uploaded, "state": s.state.Snapshot()})
}
|
|
|
|
// handleAssetEncrypt serves POST /api/assets/encrypt. The JSON body names the
// plaintext paths plus the provider, file-time ID and purpose, which are
// passed to AppState.Encrypt.
//
// Every response after a successful decode carries the per-item "errors" and
// the current "state". An overall error from Encrypt yields 400 with "error".
// Otherwise the reply is 200 with "saved". A malformed body is 400 without
// state, and other methods get 405.
func (s *HTTPServer) handleAssetEncrypt(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		PlaintextRelpaths []string `json:"plaintext_relpaths"`
		ProviderID        int      `json:"provider_id"`
		FileTimeID        int      `json:"file_time_id"`
		Purpose           string   `json:"purpose"`
	}
	if err := decodeJSON(r, &body); err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}

	saved, itemErrs, err := s.state.Encrypt(body.PlaintextRelpaths, body.ProviderID, body.FileTimeID, body.Purpose)
	reply := map[string]any{"errors": itemErrs, "state": s.state.Snapshot()}
	if err != nil {
		reply["ok"] = false
		reply["error"] = err.Error()
		writeJSON(w, http.StatusBadRequest, reply)
		return
	}
	reply["ok"] = true
	reply["saved"] = saved
	writeJSON(w, http.StatusOK, reply)
}
|
|
|
|
// handleAssetDecrypt serves POST /api/assets/decrypt. The JSON body's
// "record_ids" are passed to AppState.Decrypt.
//
// Every response after a successful decode carries the per-item "errors" and
// the current "state". An overall error from Decrypt yields 400 with "error".
// Otherwise the reply is 200 with "restored". A malformed body is 400 without
// state, and other methods get 405.
func (s *HTTPServer) handleAssetDecrypt(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"ok": false, "error": "method not allowed"})
		return
	}

	var body struct {
		RecordIDs []int `json:"record_ids"`
	}
	if err := decodeJSON(r, &body); err != nil {
		writeJSON(w, http.StatusBadRequest, map[string]any{"ok": false, "error": err.Error()})
		return
	}

	restored, itemErrs, err := s.state.Decrypt(body.RecordIDs)
	reply := map[string]any{"errors": itemErrs, "state": s.state.Snapshot()}
	if err != nil {
		reply["ok"] = false
		reply["error"] = err.Error()
		writeJSON(w, http.StatusBadRequest, reply)
		return
	}
	reply["ok"] = true
	reply["restored"] = restored
	writeJSON(w, http.StatusOK, reply)
}
|
|
|
|
// Run creates the application state rooted at assetRoot and serves the API
// and UI on addr, blocking until serving stops. The token for
// /api/remote/puncture-provider is read from PUNCTURE_REMOTE_TOKEN. Like
// http.Server.ListenAndServe, the returned error is never nil.
func Run(addr, assetRoot string) error {
	state, err := app.NewAppState(assetRoot)
	if err != nil {
		return err
	}
	h := New(state, os.Getenv("PUNCTURE_REMOTE_TOKEN"))

	// ReadHeaderTimeout stops a client from holding a connection (and its
	// goroutine) open indefinitely by trickling request headers. Read/Write
	// timeouts stay unset because upload and encrypt/decrypt requests may
	// legitimately take a long time.
	srv := &http.Server{
		Addr:              addr,
		Handler:           loggingMiddleware(h.Handler()),
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Printf("puncture-go server listening on %s", addr)
	return srv.ListenAndServe()
}
|
|
|
|
// Start creates the application state rooted at assetRoot and serves on addr
// in a background goroutine. It then polls GET /healthz (up to 4s, 250ms per
// probe) until it answers 200.
//
// On success it returns the running server, which the caller stops with
// Shutdown or Close, together with its state. If the server goroutine exits
// during startup (for example because addr is already in use), Start returns
// that error immediately instead of polling until the deadline. On timeout the
// server is closed before the error is returned, so no listener is left behind.
func Start(addr, assetRoot string) (*http.Server, *app.AppState, error) {
	state, err := app.NewAppState(assetRoot)
	if err != nil {
		return nil, nil, err
	}
	remoteToken := os.Getenv("PUNCTURE_REMOTE_TOKEN")
	h := New(state, remoteToken)

	// ReadHeaderTimeout stops slow or idle clients from pinning connections
	// forever while sending headers; body-heavy endpoints are not capped.
	srv := &http.Server{
		Addr:              addr,
		Handler:           loggingMiddleware(h.Handler()),
		ReadHeaderTimeout: 10 * time.Second,
	}

	// serveErr receives ListenAndServe's result. It is buffered so the goroutine
	// can always deliver it and exit, even long after Start has returned.
	serveErr := make(chan error, 1)
	go func() {
		err := srv.ListenAndServe()
		if err != nil && err != http.ErrServerClosed {
			log.Printf("server error: %v", err)
		}
		serveErr <- err
	}()

	healthURL := "http://" + addr + "/healthz"
	deadline := time.Now().Add(4 * time.Second)
	for time.Now().Before(deadline) {
		// Fail fast if our listener has already died.
		select {
		case err := <-serveErr:
			return nil, nil, fmt.Errorf("server on %s stopped during startup: %w", addr, err)
		default:
		}

		ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
		if err != nil {
			// A malformed addr fails the same way on every attempt, and Do would
			// panic on a nil request, so give up now.
			cancel()
			_ = srv.Close()
			return nil, nil, fmt.Errorf("build health check request for %s: %w", addr, err)
		}
		resp, err := http.DefaultClient.Do(req)
		healthy := err == nil && resp.StatusCode == http.StatusOK
		if resp != nil {
			_ = resp.Body.Close()
		}
		cancel()

		if healthy {
			// The probe can be answered by another process that already owns addr.
			// In that case our ListenAndServe fails to bind, and that error has
			// normally arrived by now, so check once more before reporting success.
			// This narrows that window but cannot close it completely.
			select {
			case err := <-serveErr:
				return nil, nil, fmt.Errorf("server on %s stopped during startup: %w", addr, err)
			default:
			}
			return srv, state, nil
		}
		time.Sleep(120 * time.Millisecond)
	}

	// Don't leave a half-started server listening after reporting failure.
	_ = srv.Close()
	return nil, nil, fmt.Errorf("server did not start in time on %s", addr)
}
|
|
|
|
func loggingMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
next.ServeHTTP(w, r)
|
|
log.Printf("%s %s (%s)", r.Method, r.URL.Path, time.Since(start).Truncate(time.Millisecond))
|
|
})
|
|
}
|
|
|
|
// ParseAddr builds a "host:port" listen address. An empty host defaults to
// 127.0.0.1 and a non-positive port defaults to 9122.
//
// A bare IPv6 literal (any host containing ':' that is not already in
// brackets, e.g. "::1" or "fe80::1%en0") is wrapped in brackets, giving
// "[::1]:9122", the same as net.JoinHostPort. Without the brackets, net.Listen
// and URL parsing cannot split the result back into host and port. Hosts that
// are already bracketed are left unchanged.
func ParseAddr(host string, port int) string {
	if host == "" {
		host = "127.0.0.1"
	}
	if port <= 0 {
		port = 9122
	}
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return host + ":" + strconv.Itoa(port)
}
|