created handlers for certificate manipulation in vault. Inserted device mTLS guards for public faced endpoints
This commit is contained in:
152
server/internal/middleware/mtls.go
Normal file
152
server/internal/middleware/mtls.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"smoop-api/internal/models"
|
||||
)
|
||||
|
||||
// Common helpers ---------------------------------------------------------------
|
||||
|
||||
func getHeader(c *gin.Context, k string) string {
|
||||
return strings.TrimSpace(c.GetHeader(k))
|
||||
}
|
||||
|
||||
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
|
||||
if escaped == "" {
|
||||
return nil, errors.New("empty client cert header")
|
||||
}
|
||||
raw, err := url.QueryUnescape(escaped)
|
||||
if err != nil {
|
||||
// Some Nginx builds already space-escape; still try raw
|
||||
raw = escaped
|
||||
}
|
||||
block, _ := pem.Decode([]byte(raw))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM in client cert header")
|
||||
}
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
func extractGuidFromSAN(cert *x509.Certificate) (string, bool) {
|
||||
for _, u := range cert.URIs {
|
||||
if strings.EqualFold(u.Scheme, "urn") {
|
||||
// urn:device:<GUID> => u.Opaque == "device:<GUID>"
|
||||
parts := strings.SplitN(u.Opaque, ":", 2)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "device") {
|
||||
return parts[1], true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func normalizeSerialHex(b []byte) string {
|
||||
h := strings.ToUpper(hex.EncodeToString(b))
|
||||
// remove leading zeros normalization optional, keep as-is
|
||||
return h
|
||||
}
|
||||
|
||||
func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) {
|
||||
var cnt int64
|
||||
if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cnt > 0, nil
|
||||
}
|
||||
|
||||
// Guard for /tasks/:guid and /renew/:guid -------------------------------------
|
||||
// Requires Nginx to forward:
|
||||
//
|
||||
// X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert
|
||||
func MTLSGuard(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
||||
c.AbortWithStatus(401)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse client cert (for GUID extraction) and serial
|
||||
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
||||
return
|
||||
}
|
||||
pathGuid := strings.TrimSpace(c.Param("guid"))
|
||||
certGuid, ok := extractGuidFromSAN(cert)
|
||||
if !ok || certGuid == "" {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
||||
return
|
||||
}
|
||||
if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
||||
rev, err := isSerialRevoked(db, serialHex)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
||||
return
|
||||
}
|
||||
if rev {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
|
||||
// Stash for downstream (device + serial)
|
||||
c.Set("mtlsDeviceGUID", certGuid)
|
||||
c.Set("mtlsSerialHex", serialHex)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Guard for /records/upload (guid comes as form field instead of path)
|
||||
func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
||||
c.AbortWithStatus(401)
|
||||
return
|
||||
}
|
||||
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
||||
return
|
||||
}
|
||||
certGuid, ok := extractGuidFromSAN(cert)
|
||||
if !ok || certGuid == "" {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
||||
return
|
||||
}
|
||||
formGuid := strings.TrimSpace(c.PostForm("guid"))
|
||||
if formGuid == "" {
|
||||
// allow query fallback if you want
|
||||
formGuid = strings.TrimSpace(c.Query("guid"))
|
||||
}
|
||||
if !strings.EqualFold(formGuid, certGuid) {
|
||||
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
||||
return
|
||||
}
|
||||
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
||||
rev, err := isSerialRevoked(db, serialHex)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
||||
return
|
||||
}
|
||||
if rev {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
c.Set("mtlsDeviceGUID", certGuid)
|
||||
c.Set("mtlsSerialHex", serialHex)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user