created handlers for certificate manipulation in vault. Inserted device mTLS guards for public faced endpoints
This commit is contained in:
@@ -32,6 +32,7 @@ type Config struct {
|
||||
}
|
||||
MediaMTX MediaMTXConfig
|
||||
JWTSecret []byte
|
||||
PkiIot vault.PKIClient
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -56,6 +57,10 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("VAULT_ADDR, VAULT_TOKEN, VAULT_KV_MOUNT and VAULT_KV_KEY must be set (or provide legacy VAULT_KV_PATH)")
|
||||
}
|
||||
|
||||
pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := vault.ReadKVv2(addr, token, mount, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -193,6 +198,7 @@ func Load() (*Config, error) {
|
||||
TokenTTL: time.Duration(tokenTTL),
|
||||
}
|
||||
|
||||
cfg.PkiIot = *pki
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -204,6 +210,15 @@ func LoadDev() (*Config, error) {
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
addr := os.Getenv("VAULT_ADDR")
|
||||
token := os.Getenv("VAULT_TOKEN")
|
||||
|
||||
pki, err := vault.NewPKI(addr, token, "pki_iot", "device", 30*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getBoolEnv := func(k string, def bool) bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(k)))
|
||||
if v == "true" || v == "1" || v == "yes" {
|
||||
@@ -286,5 +301,7 @@ func LoadDev() (*Config, error) {
|
||||
PublicBaseURL: publicBase,
|
||||
TokenTTL: time.Duration(tokenTTL),
|
||||
}
|
||||
|
||||
cfg.PkiIot = *pki
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
200
server/internal/handlers/certs.go
Normal file
200
server/internal/handlers/certs.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"smoop-api/internal/models"
|
||||
"smoop-api/internal/vault"
|
||||
)
|
||||
|
||||
type CertsHandler struct {
|
||||
db *gorm.DB
|
||||
pki *vault.PKIClient
|
||||
ttl string // e.g. "720h"
|
||||
}
|
||||
|
||||
func NewCertsHandler(db *gorm.DB, pki *vault.PKIClient, ttl string) *CertsHandler {
|
||||
return &CertsHandler{db: db, pki: pki, ttl: ttl}
|
||||
}
|
||||
|
||||
// ---- helpers ----------------------------------------------------------------
|
||||
|
||||
func readCSRFromRequest(c *gin.Context) (string, error) {
|
||||
// Accept: multipart/form with file "csr", or JSON {"csr":"PEM..."} or text/plain body
|
||||
// 1) multipart file
|
||||
if f, err := c.FormFile("csr"); err == nil && f != nil {
|
||||
ff, err := f.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ff.Close()
|
||||
b, err := io.ReadAll(ff)
|
||||
return string(b), err
|
||||
}
|
||||
// 2) JSON
|
||||
var body struct {
|
||||
CSR string `json:"csr"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err == nil && body.CSR != "" {
|
||||
return body.CSR, nil
|
||||
}
|
||||
// 3) raw text
|
||||
b, _ := io.ReadAll(c.Request.Body)
|
||||
if len(b) > 0 {
|
||||
return string(b), nil
|
||||
}
|
||||
return "", errors.New("csr required")
|
||||
}
|
||||
|
||||
func parseLeafPEM(pemStr string) (*x509.Certificate, error) {
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
return nil, errors.New("pem decode failed")
|
||||
}
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
|
||||
if escaped == "" {
|
||||
return nil, errors.New("empty client cert header")
|
||||
}
|
||||
raw, err := url.QueryUnescape(escaped)
|
||||
if err != nil {
|
||||
// Some Nginx builds already space-escape; still try raw
|
||||
raw = escaped
|
||||
}
|
||||
block, _ := pem.Decode([]byte(raw))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM in client cert header")
|
||||
}
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
// ---- enroll (no mTLS, just device exists) -----------------------------------
|
||||
// POST /enroll/:guid
|
||||
func (h *CertsHandler) Enroll(c *gin.Context) {
|
||||
guid := c.Param("guid")
|
||||
// ensure device exists
|
||||
var dev models.Device
|
||||
if err := h.db.First(&dev, "guid = ?", guid).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "lookup failed"})
|
||||
return
|
||||
}
|
||||
|
||||
csr, err := readCSRFromRequest(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// sign in Vault
|
||||
ctx, cancel := context.WithTimeout(c, 30*time.Second)
|
||||
defer cancel()
|
||||
sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// persist cert metadata
|
||||
leaf, err := parseLeafPEM(sign.Certificate)
|
||||
if err == nil {
|
||||
_ = h.db.Create(&models.DeviceCertificate{
|
||||
DeviceGUID: guid,
|
||||
SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)),
|
||||
IssuerCN: leaf.Issuer.CommonName,
|
||||
SubjectDN: leaf.Subject.String(),
|
||||
NotBefore: leaf.NotBefore,
|
||||
NotAfter: leaf.NotAfter,
|
||||
PemCert: sign.Certificate,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// response bundle
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"certificate": sign.Certificate,
|
||||
"issuing_ca": sign.IssuingCA,
|
||||
"ca_chain": sign.CAChain,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- renew (mTLS; cert must not be revoked; CSR pubkey must match current) --
|
||||
// POST /renew/:guid
|
||||
func (h *CertsHandler) Renew(c *gin.Context) {
|
||||
guidAny, _ := c.Get("mtlsDeviceGUID")
|
||||
serialAny, _ := c.Get("mtlsSerialHex")
|
||||
guid := guidAny.(string)
|
||||
currentSerial := serialAny.(string)
|
||||
|
||||
csr, err := readCSRFromRequest(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Check CSR pubkey == current client cert pubkey (strong binding)
|
||||
clientCert, err := parseClientCertFromHeader(c.GetHeader("X-SSL-Client-Cert"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client cert"})
|
||||
return
|
||||
}
|
||||
csrBlock, _ := pem.Decode([]byte(csr))
|
||||
if csrBlock == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "csr decode failed"})
|
||||
return
|
||||
}
|
||||
parsedCSR, err := x509.ParseCertificateRequest(csrBlock.Bytes)
|
||||
if err != nil || parsedCSR.PublicKey == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "csr parse failed"})
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(parsedCSR.PublicKey, clientCert.PublicKey) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "csr key does not match current certificate key"})
|
||||
return
|
||||
}
|
||||
|
||||
// sign
|
||||
ctx, cancel := context.WithTimeout(c, 30*time.Second)
|
||||
defer cancel()
|
||||
sign, err := h.pki.SignCSR(ctx, csr, "urn:device:"+guid, h.ttl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "vault sign failed"})
|
||||
return
|
||||
}
|
||||
|
||||
leaf, _ := parseLeafPEM(sign.Certificate)
|
||||
if leaf != nil {
|
||||
_ = h.db.Create(&models.DeviceCertificate{
|
||||
DeviceGUID: guid,
|
||||
SerialHex: strings.ToUpper(leaf.SerialNumber.Text(16)),
|
||||
IssuerCN: leaf.Issuer.CommonName,
|
||||
SubjectDN: leaf.Subject.String(),
|
||||
NotBefore: leaf.NotBefore,
|
||||
NotAfter: leaf.NotAfter,
|
||||
PemCert: sign.Certificate,
|
||||
}).Error
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"certificate": sign.Certificate,
|
||||
"issuing_ca": sign.IssuingCA,
|
||||
"ca_chain": sign.CAChain,
|
||||
"old_serial": currentSerial,
|
||||
})
|
||||
}
|
||||
51
server/internal/handlers/certs_admin.go
Normal file
51
server/internal/handlers/certs_admin.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"smoop-api/internal/models"
|
||||
"smoop-api/internal/vault"
|
||||
)
|
||||
|
||||
type CertsAdminHandler struct {
|
||||
db *gorm.DB
|
||||
pki *vault.PKIClient
|
||||
}
|
||||
|
||||
func NewCertsAdminHandler(db *gorm.DB, pki *vault.PKIClient) *CertsAdminHandler {
|
||||
return &CertsAdminHandler{db: db, pki: pki}
|
||||
}
|
||||
|
||||
type RevokeReq struct {
|
||||
Serial string `json:"serial" binding:"required"` // hex
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func (h *CertsAdminHandler) Revoke(c *gin.Context) {
|
||||
var req RevokeReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
serial := strings.ToUpper(strings.TrimSpace(req.Serial))
|
||||
if serial == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "serial required"})
|
||||
return
|
||||
}
|
||||
|
||||
// store locally (instant kill)
|
||||
if err := h.db.Create(&models.RevokedSerial{SerialHex: serial, Reason: req.Reason}).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "db save failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// trigger CRL refresh in Vault (best-effort)
|
||||
_ = h.pki.RebuildCRL(context.Background())
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"revoked": serial})
|
||||
}
|
||||
@@ -289,3 +289,18 @@ func (h *DevicesHandler) fetchUsers(ids []uint) ([]models.User, error) {
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (h *DevicesHandler) ListCertsByDevice(c *gin.Context) {
|
||||
guid := c.Param("guid")
|
||||
var d models.Device
|
||||
if err := h.db.Where("guid = ?", guid).First(&d).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
|
||||
return
|
||||
}
|
||||
var list []models.DeviceCertificate
|
||||
if err := h.db.Where("device_guid = ?", guid).Order("id desc").Find(&list).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"certs": list})
|
||||
}
|
||||
|
||||
152
server/internal/middleware/mtls.go
Normal file
152
server/internal/middleware/mtls.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"smoop-api/internal/models"
|
||||
)
|
||||
|
||||
// Common helpers ---------------------------------------------------------------
|
||||
|
||||
func getHeader(c *gin.Context, k string) string {
|
||||
return strings.TrimSpace(c.GetHeader(k))
|
||||
}
|
||||
|
||||
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
|
||||
if escaped == "" {
|
||||
return nil, errors.New("empty client cert header")
|
||||
}
|
||||
raw, err := url.QueryUnescape(escaped)
|
||||
if err != nil {
|
||||
// Some Nginx builds already space-escape; still try raw
|
||||
raw = escaped
|
||||
}
|
||||
block, _ := pem.Decode([]byte(raw))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM in client cert header")
|
||||
}
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
func extractGuidFromSAN(cert *x509.Certificate) (string, bool) {
|
||||
for _, u := range cert.URIs {
|
||||
if strings.EqualFold(u.Scheme, "urn") {
|
||||
// urn:device:<GUID> => u.Opaque == "device:<GUID>"
|
||||
parts := strings.SplitN(u.Opaque, ":", 2)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "device") {
|
||||
return parts[1], true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func normalizeSerialHex(b []byte) string {
|
||||
h := strings.ToUpper(hex.EncodeToString(b))
|
||||
// remove leading zeros normalization optional, keep as-is
|
||||
return h
|
||||
}
|
||||
|
||||
func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) {
|
||||
var cnt int64
|
||||
if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cnt > 0, nil
|
||||
}
|
||||
|
||||
// Guard for /tasks/:guid and /renew/:guid -------------------------------------
|
||||
// Requires Nginx to forward:
|
||||
//
|
||||
// X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert
|
||||
func MTLSGuard(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
||||
c.AbortWithStatus(401)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse client cert (for GUID extraction) and serial
|
||||
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
||||
return
|
||||
}
|
||||
pathGuid := strings.TrimSpace(c.Param("guid"))
|
||||
certGuid, ok := extractGuidFromSAN(cert)
|
||||
if !ok || certGuid == "" {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
||||
return
|
||||
}
|
||||
if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
||||
rev, err := isSerialRevoked(db, serialHex)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
||||
return
|
||||
}
|
||||
if rev {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
|
||||
// Stash for downstream (device + serial)
|
||||
c.Set("mtlsDeviceGUID", certGuid)
|
||||
c.Set("mtlsSerialHex", serialHex)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Guard for /records/upload (guid comes as form field instead of path)
|
||||
func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
||||
c.AbortWithStatus(401)
|
||||
return
|
||||
}
|
||||
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
||||
return
|
||||
}
|
||||
certGuid, ok := extractGuidFromSAN(cert)
|
||||
if !ok || certGuid == "" {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
||||
return
|
||||
}
|
||||
formGuid := strings.TrimSpace(c.PostForm("guid"))
|
||||
if formGuid == "" {
|
||||
// allow query fallback if you want
|
||||
formGuid = strings.TrimSpace(c.Query("guid"))
|
||||
}
|
||||
if !strings.EqualFold(formGuid, certGuid) {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
||||
return
|
||||
}
|
||||
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
||||
rev, err := isSerialRevoked(db, serialHex)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
||||
return
|
||||
}
|
||||
if rev {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
c.Set("mtlsDeviceGUID", certGuid)
|
||||
c.Set("mtlsSerialHex", serialHex)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
24
server/internal/models/cert.go
Normal file
24
server/internal/models/cert.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Link a device GUID to issued client certificates.
|
||||
type DeviceCertificate struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
DeviceGUID string `gorm:"index;not null"` // GUID
|
||||
SerialHex string `gorm:"uniqueIndex;size:128;not null"` // hex (upper or lower; normalize)
|
||||
IssuerCN string `gorm:"size:255"`
|
||||
SubjectDN string `gorm:"size:1024"`
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
PemCert string `gorm:"type:text"` // PEM of leaf cert
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// “Instant kill” list checked by the mTLS guard before allowing access.
|
||||
type RevokedSerial struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
SerialHex string `gorm:"uniqueIndex;size:128;not null"`
|
||||
Reason string `gorm:"size:1024"`
|
||||
CreatedAt time.Time
|
||||
}
|
||||
@@ -33,7 +33,8 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
|
||||
trackersH := handlers.NewTrackersHandler(db)
|
||||
|
||||
tasksH := handlers.NewTasksHandler(db)
|
||||
|
||||
certsH := handlers.NewCertsHandler(db, &cfg.PkiIot, "720h")
|
||||
certsAdminH := handlers.NewCertsAdminHandler(db, &cfg.PkiIot)
|
||||
// --- Public auth
|
||||
r.POST("/auth/signup", authH.SignUp)
|
||||
r.POST("/auth/signin", authH.SignIn)
|
||||
@@ -52,15 +53,17 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
|
||||
r.DELETE("/users/:id", authMW, adminOnly, usersH.Delete)
|
||||
|
||||
r.GET("/devices", authMW, middleware.DeviceAccessFilter(), devH.List)
|
||||
r.POST("/devices/create", authMW, devH.Create)
|
||||
r.POST("/devices/create", authMW, adminOnly, devH.Create)
|
||||
r.POST("/devices/:guid/rename", authMW, devH.Rename)
|
||||
r.POST("/devices/:guid/add_to_user", authMW, devH.AddToUser)
|
||||
r.POST("/devices/:guid/set_users", authMW, adminOnly, devH.SetUsers)
|
||||
r.POST("/devices/:guid/remove_from_user", authMW, devH.RemoveFromUser)
|
||||
r.POST("/device/:guid/task", authMW, middleware.DeviceAccessFilter(), tasksH.CreateTask)
|
||||
r.GET("/device/:guid/tasks", authMW, middleware.DeviceAccessFilter(), tasksH.ListDeviceTasks)
|
||||
r.GET("/device/:guid/certs", authMW, adminOnly, devH.ListCertsByDevice)
|
||||
r.POST("/certs/revoke", authMW, adminOnly, certsAdminH.Revoke)
|
||||
|
||||
r.POST("/records/upload", recH.Upload)
|
||||
r.POST("/records/upload", middleware.MTLSGuardUpload(db), recH.Upload)
|
||||
r.GET("/records", authMW, recH.List)
|
||||
r.GET("/records/:id/file", authMW, recH.File)
|
||||
|
||||
@@ -86,9 +89,11 @@ func Build(db *gorm.DB, minio *minio.Client, cfg *config.Config) *gin.Engine {
|
||||
r.POST("/trackers/:guid/set_users", authMW, adminOnly, trackersH.SetUsers)
|
||||
|
||||
// --- Device Job/Task API
|
||||
r.GET("/tasks/:guid", tasksH.DeviceNextTask) // heartbeat + fetch next task
|
||||
r.POST("/tasks/:guid", tasksH.DevicePostResult) // device posts result
|
||||
r.GET("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DeviceNextTask) // heartbeat + fetch next task
|
||||
r.POST("/tasks/:guid", middleware.MTLSGuard(db), tasksH.DevicePostResult) // device posts result
|
||||
|
||||
r.POST("/enroll/:guid", certsH.Enroll) // simple device-exists check is inside handler
|
||||
r.POST("/renew/:guid", middleware.MTLSGuard(db), certsH.Renew)
|
||||
// sensible defaults
|
||||
r.MaxMultipartMemory = 64 << 20 // 64 MiB
|
||||
_ = time.Now() // appease linters
|
||||
|
||||
86
server/internal/vault/pki.go
Normal file
86
server/internal/vault/pki.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
vault "github.com/hashicorp/vault-client-go"
|
||||
)
|
||||
|
||||
type PKIClient struct {
|
||||
client *vault.Client
|
||||
mount string // e.g. "pki_iot"
|
||||
role string // e.g. "device"
|
||||
}
|
||||
|
||||
type SignResponse struct {
|
||||
Certificate string `json:"certificate"`
|
||||
IssuingCA string `json:"issuing_ca"`
|
||||
CAChain []string `json:"ca_chain"`
|
||||
PrivateKey string `json:"private_key,omitempty"`
|
||||
PrivateKeyType string `json:"private_key_type,omitempty"`
|
||||
SerialNumber string `json:"serial_number"`
|
||||
}
|
||||
|
||||
func NewPKI(addr, token, mount, role string, timeout time.Duration) (*PKIClient, error) {
|
||||
client, err := vault.New(
|
||||
vault.WithAddress(addr),
|
||||
vault.WithRequestTimeout(timeout),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vault new: %w", err)
|
||||
}
|
||||
if err := client.SetToken(token); err != nil {
|
||||
return nil, fmt.Errorf("set token: %w", err)
|
||||
}
|
||||
return &PKIClient{client: client, mount: mount, role: role}, nil
|
||||
}
|
||||
|
||||
// SignCSR calls: /v1/<mount>/sign/<role>
|
||||
func (p *PKIClient) SignCSR(ctx context.Context, csrPEM, uriSAN string, ttl string) (*SignResponse, error) {
|
||||
if p.client == nil {
|
||||
return nil, fmt.Errorf("vault client is nil")
|
||||
}
|
||||
path := fmt.Sprintf("/%s/sign/%s", p.mount, p.role)
|
||||
req := map[string]any{
|
||||
"csr": csrPEM,
|
||||
"uri_sans": uriSAN, // e.g. "urn:device:<GUID>"
|
||||
}
|
||||
if ttl != "" {
|
||||
req["ttl"] = ttl // e.g. "720h"
|
||||
}
|
||||
|
||||
resp, err := p.client.Write(ctx, path, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil || resp.Data == nil {
|
||||
return nil, fmt.Errorf("vault sign: empty response")
|
||||
}
|
||||
|
||||
// resp.Data contains the fields we need
|
||||
var out SignResponse
|
||||
if err := mapToStruct(resp.Data, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// RebuildCRL triggers CRL regeneration (rotate).
|
||||
// HCP: POST /v1/<mount>/crl/rotate or read /crl to force render; we’ll call rotate when available,
|
||||
// else read to ensure CRL is (re)generated.
|
||||
func (p *PKIClient) RebuildCRL(ctx context.Context) error {
|
||||
_, _ = p.client.Write(ctx, fmt.Sprintf("/%s/crl/rotate", p.mount), nil) // best effort
|
||||
_, err := p.client.Read(ctx, fmt.Sprintf("/%s/crl", p.mount))
|
||||
return err
|
||||
}
|
||||
|
||||
func mapToStruct(m map[string]any, out any) error {
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, out)
|
||||
}
|
||||
Reference in New Issue
Block a user