created handlers for certificate manipulation in vault. Inserted device mTLS guards for public faced endpoints

This commit is contained in:
dtv
2025-10-04 23:12:03 +03:00
parent 35e59c4879
commit 6a5ddd66ba
8 changed files with 555 additions and 5 deletions

View File

@@ -32,6 +32,7 @@ type Config struct {
}
MediaMTX MediaMTXConfig
JWTSecret []byte
PkiIot vault.PKIClient
}
func Load() (*Config, error) {
@@ -56,6 +57,10 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("VAULT_ADDR, VAULT_TOKEN, VAULT_KV_MOUNT and VAULT_KV_KEY must be set (or provide legacy VAULT_KV_PATH)")
}
pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second)
if err != nil {
return nil, err
}
raw, err := vault.ReadKVv2(addr, token, mount, key)
if err != nil {
return nil, err
@@ -193,6 +198,7 @@ func Load() (*Config, error) {
TokenTTL: time.Duration(tokenTTL),
}
cfg.PkiIot = *pki
return cfg, nil
}
@@ -204,6 +210,15 @@ func LoadDev() (*Config, error) {
}
return v, nil
}
addr := os.Getenv("VAULT_ADDR")
token := os.Getenv("VAULT_TOKEN")
pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second)
if err != nil {
return nil, err
}
getBoolEnv := func(k string, def bool) bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(k)))
if v == "true" || v == "1" || v == "yes" {
@@ -286,5 +301,7 @@ func LoadDev() (*Config, error) {
PublicBaseURL: publicBase,
TokenTTL: time.Duration(tokenTTL),
}
cfg.PkiIot = *pki
return cfg, nil
}

View File

@@ -0,0 +1,200 @@
package handlers
import (
"context"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"smoop-api/internal/models"
"smoop-api/internal/vault"
)
type CertsHandler struct {
db *gorm.DB
pki *vault.PKIClient
ttl string // e.g. "720h"
}
func NewCertsHandler(db *gorm.DB, pki *vault.PKIClient, ttl string) *CertsHandler {
return &CertsHandler{db: db, pki: pki, ttl: ttl}
}
// ---- helpers ----------------------------------------------------------------
func readCSRFromRequest(c *gin.Context) (string, error) {
// Accept: multipart/form with file "csr", or JSON {"csr":"PEM..."} or text/plain body
// 1) multipart file
if f, err := c.FormFile("csr"); err == nil && f != nil {
ff, err := f.Open()
if err != nil {
return "", err
}
defer ff.Close()
b, err := io.ReadAll(ff)
return string(b), err
}
// 2) JSON
var body struct {
CSR string `json:"csr"`
}
if err := c.ShouldBindJSON(&body); err == nil && body.CSR != "" {
return body.CSR, nil
}
// 3) raw text
b, _ := io.ReadAll(c.Request.Body)
if len(b) > 0 {
return string(b), nil
}
return "", errors.New("csr required")
}
func parseLeafPEM(pemStr string) (*x509.Certificate, error) {
block, _ := pem.Decode([]byte(pemStr))
if block == nil {
return nil, errors.New("pem decode failed")
}
return x509.ParseCertificate(block.Bytes)
}
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
if escaped == "" {
return nil, errors.New("empty client cert header")
}
raw, err := url.QueryUnescape(escaped)
if err != nil {
// Some Nginx builds already space-escape; still try raw
raw = escaped
}
block, _ := pem.Decode([]byte(raw))
if block == nil {
return nil, errors.New("failed to decode PEM in client cert header")
}
return x509.ParseCertificate(block.Bytes)
}
// ---- enroll (no mTLS, just device exists) -----------------------------------
// POST /enroll/:guid
func (h *CertsHandler) Enroll(c *gin.Context) {
guid := c.Param("guid")
// ensure device exists
var dev models.Device
if err := h.db.First(&dev, "guid = ?", guid).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "lookup failed"})
return
}
csr, err := readCSRFromRequest(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// sign in Vault
ctx, cancel := context.WithTimeout(c, 30*time.Second)
defer cancel()
sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"})
return
}
// persist cert metadata
leaf, err := parseLeafPEM(sign.Certificate)
if err == nil {
_ = h.db.Create(&models.DeviceCertificate{
DeviceGUID: guid,
SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)),
IssuerCN: leaf.Issuer.CommonName,
SubjectDN: leaf.Subject.String(),
NotBefore: leaf.NotBefore,
NotAfter: leaf.NotAfter,
PemCert: sign.Certificate,
}).Error
}
// response bundle
c.JSON(http.StatusOK, gin.H{
"certificate": sign.Certificate,
"issuing_ca": sign.IssuingCA,
"ca_chain": sign.CAChain,
})
}
// ---- renew (mTLS; cert must not be revoked; CSR pubkey must match current) --
// POST /renew/:guid
func (h *CertsHandler) Renew(c *gin.Context) {
guidAny, _ := c.Get("mtlsDeviceGUID")
serialAny, _ := c.Get("mtlsSerialHex")
guid := guidAny.(string)
currentSerial := serialAny.(string)
csr, err := readCSRFromRequest(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Check CSR pubkey == current client cert pubkey (strong binding)
clientCert, err := parseClientCertFromHeader(c.GetHeader("X-SSL-Client-Cert"))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client cert"})
return
}
csrBlock, _ := pem.Decode([]byte(csr))
if csrBlock == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "csr decode failed"})
return
}
parsedCSR, err := x509.ParseCertificateRequest(csrBlock.Bytes)
if err != nil || parsedCSR.PublicKey == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "csr parse failed"})
return
}
if !reflect.DeepEqual(parsedCSR.PublicKey, clientCert.PublicKey) {
c.JSON(http.StatusForbidden, gin.H{"error": "csr key does not match current certificate key"})
return
}
// sign
ctx, cancel := context.WithTimeout(c, 30*time.Second)
defer cancel()
sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"})
return
}
leaf, _ := parseLeafPEM(sign.Certificate)
if leaf != nil {
_ = h.db.Create(&models.DeviceCertificate{
DeviceGUID: guid,
SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)),
IssuerCN: leaf.Issuer.CommonName,
SubjectDN: leaf.Subject.String(),
NotBefore: leaf.NotBefore,
NotAfter: leaf.NotAfter,
PemCert: sign.Certificate,
}).Error
}
c.JSON(http.StatusOK, gin.H{
"certificate": sign.Certificate,
"issuing_ca": sign.IssuingCA,
"ca_chain": sign.CAChain,
"old_serial": currentSerial,
})
}

View File

@@ -0,0 +1,51 @@
package handlers
import (
"context"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"smoop-api/internal/models"
"smoop-api/internal/vault"
)
type CertsAdminHandler struct {
db *gorm.DB
pki *vault.PKIClient
}
func NewCertsAdminHandler(db *gorm.DB, pki *vault.PKIClient) *CertsAdminHandler {
return &CertsAdminHandler{db: db, pki: pki}
}
type RevokeReq struct {
Serial string `json:"serial" binding:"required"` // hex
Reason string `json:"reason"`
}
func (h *CertsAdminHandler) Revoke(c *gin.Context) {
var req RevokeReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
serial := strings.ToUpper(strings.TrimSpace(req.Serial))
if serial == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "serial required"})
return
}
// store locally (instant kill)
if err := h.db.Create(&models.RevokedSerial{SerialHex: serial, Reason: req.Reason}).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "db save failed"})
return
}
// trigger CRL refresh in Vault (best-effort)
_ = h.pki.RebuildCRL(context.Background())
c.JSON(http.StatusOK, gin.H{"revoked": serial})
}

View File

@@ -289,3 +289,18 @@ func (h *DevicesHandler) fetchUsers(ids []uint) ([]models.User, error) {
}
return users, nil
}
func (h *DevicesHandler) ListCertsByDevice(c *gin.Context) {
guid := c.Param("guid")
var d models.Device
if err := h.db.Where("guid = ?", guid).First(&d).Error; err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
return
}
var list []models.DeviceCertificate
if err := h.db.Where("device_guid = ?", guid).Order("id desc").Find(&list).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
return
}
c.JSON(http.StatusOK, gin.H{"certs": list})
}

View File

@@ -0,0 +1,152 @@
package middleware
import (
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"smoop-api/internal/models"
)
// Common helpers ---------------------------------------------------------------
func getHeader(c *gin.Context, k string) string {
return strings.TrimSpace(c.GetHeader(k))
}
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
if escaped == "" {
return nil, errors.New("empty client cert header")
}
raw, err := url.QueryUnescape(escaped)
if err != nil {
// Some Nginx builds already space-escape; still try raw
raw = escaped
}
block, _ := pem.Decode([]byte(raw))
if block == nil {
return nil, errors.New("failed to decode PEM in client cert header")
}
return x509.ParseCertificate(block.Bytes)
}
func extractGuidFromSAN(cert *x509.Certificate) (string, bool) {
for _, u := range cert.URIs {
if strings.EqualFold(u.Scheme, "urn") {
// urn:device:<GUID> => u.Opaque == "device:<GUID>"
parts := strings.SplitN(u.Opaque, ":", 2)
if len(parts) == 2 && strings.EqualFold(parts[0], "device") {
return parts[1], true
}
}
}
return "", false
}
func normalizeSerialHex(b []byte) string {
h := strings.ToUpper(hex.EncodeToString(b))
// remove leading zeros normalization optional, keep as-is
return h
}
func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) {
var cnt int64
if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil {
return false, err
}
return cnt > 0, nil
}
// Guard for /tasks/:guid and /renew/:guid -------------------------------------
// Requires Nginx to forward:
//
// X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert
func MTLSGuard(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
c.AbortWithStatus(401)
return
}
// Parse client cert (for GUID extraction) and serial
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
if err != nil {
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
return
}
pathGuid := strings.TrimSpace(c.Param("guid"))
certGuid, ok := extractGuidFromSAN(cert)
if !ok || certGuid == "" {
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
return
}
if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) {
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
return
}
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
rev, err := isSerialRevoked(db, serialHex)
if err != nil {
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
return
}
if rev {
c.AbortWithStatus(403)
return
}
// Stash for downstream (device + serial)
c.Set("mtlsDeviceGUID", certGuid)
c.Set("mtlsSerialHex", serialHex)
c.Next()
}
}
// Guard for /records/upload (guid comes as form field instead of path)
func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
c.AbortWithStatus(401)
return
}
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
if err != nil {
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
return
}
certGuid, ok := extractGuidFromSAN(cert)
if !ok || certGuid == "" {
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
return
}
formGuid := strings.TrimSpace(c.PostForm("guid"))
if formGuid == "" {
// allow query fallback if you want
formGuid = strings.TrimSpace(c.Query("guid"))
}
if !strings.EqualFold(formGuid, certGuid) {
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
return
}
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
rev, err := isSerialRevoked(db, serialHex)
if err != nil {
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
return
}
if rev {
c.AbortWithStatus(403)
return
}
c.Set("mtlsDeviceGUID", certGuid)
c.Set("mtlsSerialHex", serialHex)
c.Next()
}
}

View File

@@ -0,0 +1,24 @@
package models
import "time"
// Link a device GUID to issued client certificates.
type DeviceCertificate struct {
ID uint `gorm:"primaryKey"`
DeviceGUID string `gorm:"index;not null"` // GUID
SerialHex string `gorm:"uniqueIndex;size:128;not null"` // hex (upper or lower; normalize)
IssuerCN string `gorm:"size:255"`
SubjectDN string `gorm:"size:1024"`
NotBefore time.Time
NotAfter time.Time
PemCert string `gorm:"type:text"` // PEM of leaf cert
CreatedAt time.Time
}
// “Instant kill” list checked by the mTLS guard before allowing access.
type RevokedSerial struct {
ID uint `gorm:"primaryKey"`
SerialHex string `gorm:"uniqueIndex;size:128;not null"`
Reason string `gorm:"size:1024"`
CreatedAt time.Time
}

View File

@@ -33,7 +33,8 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
trackersH := handlers.NewTrackersHandler(db)
tasksH := handlers.NewTasksHandler(db)
certsH := handlers.NewCertsHandler(db, &cfg.PkiIot, "720h")
certsAdminH := handlers.NewCertsAdminHandler(db, &cfg.PkiIot)
// --- Public auth
r.POST("/auth/signup", authH.SignUp)
r.POST("/auth/signin", authH.SignIn)
@@ -52,15 +53,17 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
r.DELETE("/users/:id", authMW, adminOnly, usersH.Delete)
r.GET("/devices", authMW, middleware.DeviceAccessFilter(), devH.List)
r.POST("/devices/create", authMW, devH.Create)
r.POST("/devices/create", authMW, adminOnly, devH.Create)
r.POST("/devices/:guid/rename", authMW, devH.Rename)
r.POST("/devices/:guid/add_to_user", authMW, devH.AddToUser)
r.POST("/devices/:guid/set_users", authMW, adminOnly, devH.SetUsers)
r.POST("/devices/:guid/remove_from_user", authMW, devH.RemoveFromUser)
r.POST("/device/:guid/task", authMW, middleware.DeviceAccessFilter(), tasksH.CreateTask)
r.GET("/device/:guid/tasks", authMW, middleware.DeviceAccessFilter(), tasksH.ListDeviceTasks)
r.GET("/device/:guid/certs", authMW, adminOnly, devH.ListCertsByDevice)
r.POST("/certs/revoke", authMW, adminOnly, certsAdminH.Revoke)
r.POST("/records/upload", recH.Upload)
r.POST("/records/upload", middleware.MTLSGuardUpload(db), recH.Upload)
r.GET("/records", authMW, recH.List)
r.GET("/records/:id/file", authMW, recH.File)
@@ -86,9 +89,11 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
r.POST("/trackers/:guid/set_users", authMW, adminOnly, trackersH.SetUsers)
// --- Device Job/Task API
r.GET("/tasks/:guid", tasksH.DeviceNextTask) // heartbeat + fetch next task
r.POST("/tasks/:guid", tasksH.DevicePostResult) // device posts result
r.GET("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DeviceNextTask) // heartbeat + fetch next task
r.POST("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DevicePostResult) // device posts result
r.POST("/enroll/:guid", certsH.Enroll) // simple device-exists check is inside handler
r.POST("/renew/:guid", middleware.MTLSGuard(db), certsH.Renew)
// sensible defaults
r.MaxMultipartMemory = 64 << 20 // 64 MiB
_ = time.Now() // appease linters

View File

@@ -0,0 +1,86 @@
package vault
import (
"context"
"encoding/json"
"fmt"
"time"
vault "github.com/hashicorp/vault-client-go"
)
type PKIClient struct {
client *vault.Client
mount string // e.g. "pki_iot"
role string // e.g. "device"
}
type SignResponse struct {
Certificate string `json:"certificate"`
IssuingCA string `json:"issuing_ca"`
CAChain []string `json:"ca_chain"`
PrivateKey string `json:"private_key,omitempty"`
PrivateKeyType string `json:"private_key_type,omitempty"`
SerialNumber string `json:"serial_number"`
}
func NewPKI(addr, token, mount, role string, timeout time.Duration) (*PKIClient, error) {
client, err := vault.New(
vault.WithAddress(addr),
vault.WithRequestTimeout(timeout),
)
if err != nil {
return nil, fmt.Errorf("vault new: %w", err)
}
if err := client.SetToken(token); err != nil {
return nil, fmt.Errorf("set token: %w", err)
}
return &PKIClient{client: client, mount: mount, role: role}, nil
}
// SignCSR calls: /v1/<mount>/sign/<role>
func (p *PKIClient) SignCSR(ctx context.Context, csrPEM, uriSAN string, ttl string) (*SignResponse, error) {
if p.client == nil {
return nil, fmt.Errorf("vault client is nil")
}
path := fmt.Sprintf("/%s/sign/%s", p.mount, p.role)
req := map[string]any{
"csr": csrPEM,
"uri_sans": uriSAN, // e.g. "urn:device:<GUID>"
}
if ttl != "" {
req["ttl"] = ttl // e.g. "720h"
}
resp, err := p.client.Write(ctx, path, req)
if err != nil {
return nil, err
}
if resp == nil || resp.Data == nil {
return nil, fmt.Errorf("vault sign: empty response")
}
// resp.Data contains the fields we need
var out SignResponse
if err := mapToStruct(resp.Data, &out); err != nil {
return nil, err
}
return &out, nil
}
// RebuildCRL triggers CRL regeneration (rotate).
// HCP: POST /v1/<mount>/crl/rotate or read /crl to force render; well call rotate when available,
// else read to ensure CRL is (re)generated.
func (p *PKIClient) RebuildCRL(ctx context.Context) error {
_, _ = p.client.Write(ctx, fmt.Sprintf("/%s/crl/rotate", p.mount), nil) // best effort
_, err := p.client.Read(ctx, fmt.Sprintf("/%s/crl", p.mount))
return err
}
func mapToStruct(m map[string]any, out any) error {
b, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(b, out)
}