From 6a5ddd66baab2815a9bf6218fec48bcedd2503e9 Mon Sep 17 00:00:00 2001 From: dtv Date: Sat, 4 Oct 2025 23:12:03 +0300 Subject: [PATCH] created handlers for certificate manipulation in vault. Inserted device mTLS guards for public faced endpoints --- server/internal/config/config.go | 17 ++ server/internal/handlers/certs.go | 200 ++++++++++++++++++++++++ server/internal/handlers/certs_admin.go | 51 ++++++ server/internal/handlers/devices.go | 15 ++ server/internal/middleware/mtls.go | 152 ++++++++++++++++++ server/internal/models/cert.go | 24 +++ server/internal/router/router.go | 15 +- server/internal/vault/pki.go | 86 ++++++++++ 8 files changed, 555 insertions(+), 5 deletions(-) create mode 100644 server/internal/handlers/certs.go create mode 100644 server/internal/handlers/certs_admin.go create mode 100644 server/internal/middleware/mtls.go create mode 100644 server/internal/models/cert.go create mode 100644 server/internal/vault/pki.go diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 15cf27a..8609fee 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -32,6 +32,7 @@ type Config struct { } MediaMTX MediaMTXConfig JWTSecret []byte + PkiIot vault.PKIClient } func Load() (*Config, error) { @@ -56,6 +57,10 @@ func Load() (*Config, error) { return nil, fmt.Errorf("VAULT_ADDR, VAULT_TOKEN, VAULT_KV_MOUNT and VAULT_KV_KEY must be set (or provide legacy VAULT_KV_PATH)") } + pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second) + if err != nil { + return nil, err + } raw, err := vault.ReadKVv2(addr, token, mount, key) if err != nil { return nil, err @@ -193,6 +198,7 @@ func Load() (*Config, error) { TokenTTL: time.Duration(tokenTTL), } + cfg.PkiIot = *pki return cfg, nil } @@ -204,6 +210,15 @@ func LoadDev() (*Config, error) { } return v, nil } + + addr := os.Getenv("VAULT_ADDR") + token := os.Getenv("VAULT_TOKEN") + + pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second) + if err != nil { + return nil, err + } + getBoolEnv := func(k string, def bool) bool { v := strings.ToLower(strings.TrimSpace(os.Getenv(k))) if v == "true" || v == "1" || v == "yes" { @@ -286,5 +301,7 @@ func LoadDev() (*Config, error) { PublicBaseURL: publicBase, TokenTTL: time.Duration(tokenTTL), } + + cfg.PkiIot = *pki return cfg, nil } diff --git a/server/internal/handlers/certs.go b/server/internal/handlers/certs.go new file mode 100644 index 0000000..62fba28 --- /dev/null +++ b/server/internal/handlers/certs.go @@ -0,0 +1,200 @@ +package handlers + +import ( + "context" + "crypto/x509" + "encoding/pem" + "errors" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" + + "smoop-api/internal/models" + "smoop-api/internal/vault" +) + +type CertsHandler struct { + db *gorm.DB + pki *vault.PKIClient + ttl string // e.g. "720h" +} + +func NewCertsHandler(db *gorm.DB, pki *vault.PKIClient, ttl string) *CertsHandler { + return &CertsHandler{db: db, pki: pki, ttl: ttl} +} + +// ---- helpers ---------------------------------------------------------------- + +func readCSRFromRequest(c *gin.Context) (string, error) { + // Accept: multipart/form with file "csr", or JSON {"csr":"PEM..."} or text/plain body + // 1) multipart file + if f, err := c.FormFile("csr"); err == nil && f != nil { + ff, err := f.Open() + if err != nil { + return "", err + } + defer ff.Close() + b, err := io.ReadAll(ff) + return string(b), err + } + // 2) JSON + var body struct { + CSR string `json:"csr"` + } + if err := c.ShouldBindJSON(&body); err == nil && body.CSR != "" { + return body.CSR, nil + } + // 3) raw text + b, _ := io.ReadAll(c.Request.Body) + if len(b) > 0 { + return string(b), nil + } + return "", errors.New("csr required") +} + +func parseLeafPEM(pemStr string) (*x509.Certificate, error) { + block, _ := pem.Decode([]byte(pemStr)) + if block == nil { + return nil, errors.New("pem decode failed") + } + return x509.ParseCertificate(block.Bytes) +} + +func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) { + if escaped == "" { + return nil, errors.New("empty client cert header") + } + raw, err := url.QueryUnescape(escaped) + if err != nil { + // Some Nginx builds already space-escape; still try raw + raw = escaped + } + block, _ := pem.Decode([]byte(raw)) + if block == nil { + return nil, errors.New("failed to decode PEM in client cert header") + } + return x509.ParseCertificate(block.Bytes) +} + +// ---- enroll (no mTLS, just device exists) ----------------------------------- +// POST /enroll/:guid +func (h *CertsHandler) Enroll(c *gin.Context) { + guid := c.Param("guid") + // ensure device exists + var dev models.Device + if err := h.db.First(&dev, "guid = ?", guid).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "device not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "lookup failed"}) + return + } + + csr, err := readCSRFromRequest(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // sign in Vault + ctx, cancel := context.WithTimeout(c, 30*time.Second) + defer cancel() + sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"}) + return + } + + // persist cert metadata + leaf, err := parseLeafPEM(sign.Certificate) + if err == nil { + _ = h.db.Create(&models.DeviceCertificate{ + DeviceGUID: guid, + SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)), + IssuerCN: leaf.Issuer.CommonName, + SubjectDN: leaf.Subject.String(), + NotBefore: leaf.NotBefore, + NotAfter: leaf.NotAfter, + PemCert: sign.Certificate, + }).Error + } + + // response bundle + c.JSON(http.StatusOK, gin.H{ + "certificate": sign.Certificate, + "issuing_ca": sign.IssuingCA, + "ca_chain": sign.CAChain, + }) +} + +// ---- renew (mTLS; cert must not be revoked; CSR pubkey must match current) -- +// POST /renew/:guid +func (h *CertsHandler) Renew(c *gin.Context) { + guidAny, _ := c.Get("mtlsDeviceGUID") + serialAny, _ := c.Get("mtlsSerialHex") + guid := guidAny.(string) + currentSerial := serialAny.(string) + + csr, err := readCSRFromRequest(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Check CSR pubkey == current client cert pubkey (strong binding) + clientCert, err := parseClientCertFromHeader(c.GetHeader("X-SSL-Client-Cert")) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client cert"}) + return + } + csrBlock, _ := pem.Decode([]byte(csr)) + if csrBlock == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "csr decode failed"}) + return + } + parsedCSR, err := x509.ParseCertificateRequest(csrBlock.Bytes) + if err != nil || parsedCSR.PublicKey == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "csr parse failed"}) + return + } + if !reflect.DeepEqual(parsedCSR.PublicKey, clientCert.PublicKey) { + c.JSON(http.StatusForbidden, gin.H{"error": "csr key does not match current certificate key"}) + return + } + + // sign + ctx, cancel := context.WithTimeout(c, 30*time.Second) + defer cancel() + sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"}) + return + } + + leaf, _ := parseLeafPEM(sign.Certificate) + if leaf != nil { + _ = h.db.Create(&models.DeviceCertificate{ + DeviceGUID: guid, + SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)), + IssuerCN: leaf.Issuer.CommonName, + SubjectDN: leaf.Subject.String(), + NotBefore: leaf.NotBefore, + NotAfter: leaf.NotAfter, + PemCert: sign.Certificate, + }).Error + } + + c.JSON(http.StatusOK, gin.H{ + "certificate": sign.Certificate, + "issuing_ca": sign.IssuingCA, + "ca_chain": sign.CAChain, + "old_serial": currentSerial, + }) +} diff --git a/server/internal/handlers/certs_admin.go b/server/internal/handlers/certs_admin.go new file mode 100644 index 0000000..d8747be --- /dev/null +++ b/server/internal/handlers/certs_admin.go @@ -0,0 +1,51 @@ +package handlers + +import ( + "context" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" + + "smoop-api/internal/models" + "smoop-api/internal/vault" +) + +type CertsAdminHandler struct { + db *gorm.DB + pki *vault.PKIClient +} + +func NewCertsAdminHandler(db *gorm.DB, pki *vault.PKIClient) *CertsAdminHandler { + return &CertsAdminHandler{db: db, pki: pki} +} + +type RevokeReq struct { + Serial string `json:"serial" binding:"required"` // hex + Reason string `json:"reason"` +} + +func (h *CertsAdminHandler) Revoke(c *gin.Context) { + var req RevokeReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + serial := strings.ToUpper(strings.TrimSpace(req.Serial)) + if serial == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "serial required"}) + return + } + + // store locally (instant kill) + if err := h.db.Create(&models.RevokedSerial{SerialHex: serial, Reason: req.Reason}).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "db save failed"}) + return + } + + // trigger CRL refresh in Vault (best-effort) + _ = h.pki.RebuildCRL(context.Background()) + + c.JSON(http.StatusOK, gin.H{"revoked": serial}) +} diff --git a/server/internal/handlers/devices.go b/server/internal/handlers/devices.go index 7fe3e9f..61c3c62 100644 --- a/server/internal/handlers/devices.go +++ b/server/internal/handlers/devices.go @@ -289,3 +289,18 @@ func (h *DevicesHandler) fetchUsers(ids []uint) ([]models.User, error) { } return users, nil } + +func (h *DevicesHandler) ListCertsByDevice(c *gin.Context) { + guid := c.Param("guid") + var d models.Device + if err := h.db.Where("guid = ?", guid).First(&d).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "device not found"}) + return + } + var list []models.DeviceCertificate + if err := h.db.Where("device_guid = ?", guid).Order("id desc").Find(&list).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"}) + return + } + c.JSON(http.StatusOK, gin.H{"certs": list}) +} diff --git a/server/internal/middleware/mtls.go b/server/internal/middleware/mtls.go new file mode 100644 index 0000000..a680c17 --- /dev/null +++ b/server/internal/middleware/mtls.go @@ -0,0 +1,152 @@ +package middleware + +import ( + "crypto/x509" + "encoding/hex" + "encoding/pem" + "errors" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" + + "smoop-api/internal/models" +) + +// Common helpers --------------------------------------------------------------- + +func getHeader(c *gin.Context, k string) string { + return strings.TrimSpace(c.GetHeader(k)) +} + +func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) { + if escaped == "" { + return nil, errors.New("empty client cert header") + } + raw, err := url.QueryUnescape(escaped) + if err != nil { + // Some Nginx builds already space-escape; still try raw + raw = escaped + } + block, _ := pem.Decode([]byte(raw)) + if block == nil { + return nil, errors.New("failed to decode PEM in client cert header") + } + return x509.ParseCertificate(block.Bytes) +} + +func extractGuidFromSAN(cert *x509.Certificate) (string, bool) { + for _, u := range cert.URIs { + if strings.EqualFold(u.Scheme, "urn") { + // urn:device: => u.Opaque == "device:" + parts := strings.SplitN(u.Opaque, ":", 2) + if len(parts) == 2 && strings.EqualFold(parts[0], "device") { + return parts[1], true + } + } + } + return "", false +} + +func normalizeSerialHex(b []byte) string { + h := strings.ToUpper(hex.EncodeToString(b)) + // remove leading zeros normalization optional, keep as-is + return h +} + +func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) { + var cnt int64 + if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil { + return false, err + } + return cnt > 0, nil +} + +// Guard for /tasks/:guid and /renew/:guid ------------------------------------- +// Requires Nginx to forward: +// +// X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert +func MTLSGuard(db *gorm.DB) gin.HandlerFunc { + return func(c *gin.Context) { + if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" { + c.AbortWithStatus(401) + return + } + + // Parse client cert (for GUID extraction) and serial + cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert")) + if err != nil { + c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"}) + return + } + pathGuid := strings.TrimSpace(c.Param("guid")) + certGuid, ok := extractGuidFromSAN(cert) + if !ok || certGuid == "" { + c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"}) + return + } + if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) { + c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"}) + return + } + + serialHex := normalizeSerialHex(cert.SerialNumber.Bytes()) + rev, err := isSerialRevoked(db, serialHex) + if err != nil { + c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"}) + return + } + if rev { + c.AbortWithStatus(403) + return + } + + // Stash for downstream (device + serial) + c.Set("mtlsDeviceGUID", certGuid) + c.Set("mtlsSerialHex", serialHex) + c.Next() + } +} + +// Guard for /records/upload (guid comes as form field instead of path) +func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc { + return func(c *gin.Context) { + if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" { + c.AbortWithStatus(401) + return + } + cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert")) + if err != nil { + c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"}) + return + } + certGuid, ok := extractGuidFromSAN(cert) + if !ok || certGuid == "" { + c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"}) + return + } + formGuid := strings.TrimSpace(c.PostForm("guid")) + if formGuid == "" { + // allow query fallback if you want + formGuid = strings.TrimSpace(c.Query("guid")) + } + if !strings.EqualFold(formGuid, certGuid) { + c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"}) + return + } + serialHex := normalizeSerialHex(cert.SerialNumber.Bytes()) + rev, err := isSerialRevoked(db, serialHex) + if err != nil { + c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"}) + return + } + if rev { + c.AbortWithStatus(403) + return + } + c.Set("mtlsDeviceGUID", certGuid) + c.Set("mtlsSerialHex", serialHex) + c.Next() + } +} diff --git a/server/internal/models/cert.go b/server/internal/models/cert.go new file mode 100644 index 0000000..6715fa8 --- /dev/null +++ b/server/internal/models/cert.go @@ -0,0 +1,24 @@ +package models + +import "time" + +// Link a device GUID to issued client certificates. +type DeviceCertificate struct { + ID uint `gorm:"primaryKey"` + DeviceGUID string `gorm:"index;not null"` // GUID + SerialHex string `gorm:"uniqueIndex;size:128;not null"` // hex (upper or lower; normalize) + IssuerCN string `gorm:"size:255"` + SubjectDN string `gorm:"size:1024"` + NotBefore time.Time + NotAfter time.Time + PemCert string `gorm:"type:text"` // PEM of leaf cert + CreatedAt time.Time +} + +// “Instant kill” list checked by the mTLS guard before allowing access. +type RevokedSerial struct { + ID uint `gorm:"primaryKey"` + SerialHex string `gorm:"uniqueIndex;size:128;not null"` + Reason string `gorm:"size:1024"` + CreatedAt time.Time +} diff --git a/server/internal/router/router.go b/server/internal/router/router.go index 7d8eb91..c9bfce0 100644 --- a/server/internal/router/router.go +++ b/server/internal/router/router.go @@ -33,7 +33,8 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine { trackersH := handlers.NewTrackersHandler(db) tasksH := handlers.NewTasksHandler(db) - + certsH := handlers.NewCertsHandler(db, &cfg.PkiIot, "720h") + certsAdminH := handlers.NewCertsAdminHandler(db, &cfg.PkiIot) // --- Public auth r.POST("/auth/signup", authH.SignUp) r.POST("/auth/signin", authH.SignIn) @@ -52,15 +53,17 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine { r.DELETE("/users/:id", authMW, adminOnly, usersH.Delete) r.GET("/devices", authMW, middleware.DeviceAccessFilter(), devH.List) - r.POST("/devices/create", authMW, devH.Create) + r.POST("/devices/create", authMW, adminOnly, devH.Create) r.POST("/devices/:guid/rename", authMW, devH.Rename) r.POST("/devices/:guid/add_to_user", authMW, devH.AddToUser) r.POST("/devices/:guid/set_users", authMW, adminOnly, devH.SetUsers) r.POST("/devices/:guid/remove_from_user", authMW, devH.RemoveFromUser) r.POST("/device/:guid/task", authMW, middleware.DeviceAccessFilter(), tasksH.CreateTask) r.GET("/device/:guid/tasks", authMW, middleware.DeviceAccessFilter(), tasksH.ListDeviceTasks) + r.GET("/device/:guid/certs", authMW, adminOnly, devH.ListCertsByDevice) + r.POST("/certs/revoke", authMW, adminOnly, certsAdminH.Revoke) - r.POST("/records/upload", recH.Upload) + r.POST("/records/upload", middleware.MTLSGuardUpload(db), recH.Upload) r.GET("/records", authMW, recH.List) r.GET("/records/:id/file", authMW, recH.File) @@ -86,9 +89,11 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine { r.POST("/trackers/:guid/set_users", authMW, adminOnly, trackersH.SetUsers) // --- Device Job/Task API - r.GET("/tasks/:guid", tasksH.DeviceNextTask) // heartbeat + fetch next task - r.POST("/tasks/:guid", tasksH.DevicePostResult) // device posts result + r.GET("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DeviceNextTask) // heartbeat + fetch next task + r.POST("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DevicePostResult) // device posts result + r.POST("/enroll/:guid", certsH.Enroll) // simple device-exists check is inside handler + r.POST("/renew/:guid", middleware.MTLSGuard(db), certsH.Renew) // sensible defaults r.MaxMultipartMemory = 64 << 20 // 64 MiB _ = time.Now() // appease linters diff --git a/server/internal/vault/pki.go b/server/internal/vault/pki.go new file mode 100644 index 0000000..6244a2b --- /dev/null +++ b/server/internal/vault/pki.go @@ -0,0 +1,86 @@ +package vault + +import ( + "context" + "encoding/json" + "fmt" + "time" + + vault "github.com/hashicorp/vault-client-go" +) + +type PKIClient struct { + client *vault.Client + mount string // e.g. "pki_iot" + role string // e.g. "device" +} + +type SignResponse struct { + Certificate string `json:"certificate"` + IssuingCA string `json:"issuing_ca"` + CAChain []string `json:"ca_chain"` + PrivateKey string `json:"private_key,omitempty"` + PrivateKeyType string `json:"private_key_type,omitempty"` + SerialNumber string `json:"serial_number"` +} + +func NewPKI(addr, token, mount, role string, timeout time.Duration) (*PKIClient, error) { + client, err := vault.New( + vault.WithAddress(addr), + vault.WithRequestTimeout(timeout), + ) + if err != nil { + return nil, fmt.Errorf("vault new: %w", err) + } + if err := client.SetToken(token); err != nil { + return nil, fmt.Errorf("set token: %w", err) + } + return &PKIClient{client: client, mount: mount, role: role}, nil +} + +// SignCSR calls: /v1//sign/ +func (p *PKIClient) SignCSR(ctx context.Context, csrPEM, uriSAN string, ttl string) (*SignResponse, error) { + if p.client == nil { + return nil, fmt.Errorf("vault client is nil") + } + path := fmt.Sprintf("/%s/sign/%s", p.mount, p.role) + req := map[string]any{ + "csr": csrPEM, + "uri_sans": uriSAN, // e.g. "urn:device:" + } + if ttl != "" { + req["ttl"] = ttl // e.g. "720h" + } + + resp, err := p.client.Write(ctx, path, req) + if err != nil { + return nil, err + } + if resp == nil || resp.Data == nil { + return nil, fmt.Errorf("vault sign: empty response") + } + + // resp.Data contains the fields we need + var out SignResponse + if err := mapToStruct(resp.Data, &out); err != nil { + return nil, err + } + return &out, nil +} + +// RebuildCRL triggers CRL regeneration (rotate). +// HCP: POST /v1//crl/rotate or read /crl to force render; we’ll call rotate when available, +// else read to ensure CRL is (re)generated. +func (p *PKIClient) RebuildCRL(ctx context.Context) error { + _, _ = p.client.Write(ctx, fmt.Sprintf("/%s/crl/rotate", p.mount), nil) // best effort + _, err := p.client.Read(ctx, fmt.Sprintf("/%s/crl", p.mount)) + return err +} + +func mapToStruct(m map[string]any, out any) error { + b, err := json.Marshal(m) + if err != nil { + return err + } + return json.Unmarshal(b, out) +}