Files
NewSmoop/server/internal/middleware/mtls.go

153 lines
4.1 KiB
Go

package middleware
import (
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"smoop-api/internal/models"
)
// Common helpers ---------------------------------------------------------------
// getHeader returns the value of request header k with leading and trailing
// whitespace removed; an absent header yields "".
func getHeader(c *gin.Context, k string) string {
	value := c.GetHeader(k)
	return strings.TrimSpace(value)
}
// parseClientCertFromHeader decodes the PEM client certificate that the
// TLS-terminating proxy forwards in a request header and parses it.
//
// The header value may arrive in several encodings, so several decodings are
// tried in order and the first one that yields a parseable certificate wins:
//
//  1. url.QueryUnescape: percent-encoding plus '+' as space (form style). This
//     is also correct for Nginx's $ssl_client_escaped_cert, which escapes '+'
//     as %2B.
//  2. url.PathUnescape: percent-encoding where '+' is kept literally.
//  3. The value verbatim (plain PEM).
//
// Steps 2 and 3 matter because '+' is a legal base64 character in the PEM
// body. Decoding plain PEM with QueryUnescape succeeds without error but
// silently turns every '+' into a space, corrupting the certificate.
//
// It returns an error if the value is empty or no decoding yields a PEM block.
// If a PEM block was found but did not parse, it returns the error from
// x509.ParseCertificate for the first such block.
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
	if escaped == "" {
		return nil, errors.New("empty client cert header")
	}

	candidates := make([]string, 0, 3)
	if s, err := url.QueryUnescape(escaped); err == nil {
		candidates = append(candidates, s)
	}
	if s, err := url.PathUnescape(escaped); err == nil {
		candidates = append(candidates, s)
	}
	candidates = append(candidates, escaped)

	var parseErr error
	for _, raw := range candidates {
		block, _ := pem.Decode([]byte(raw))
		if block == nil {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err == nil {
			return cert, nil
		}
		// Keep the first parse failure; a later candidate may still succeed.
		if parseErr == nil {
			parseErr = err
		}
	}
	if parseErr != nil {
		return nil, parseErr
	}
	return nil, errors.New("failed to decode PEM in client cert header")
}
// extractGuidFromSAN scans the certificate's URI SANs for the first one of the
// form urn:device:<GUID>. Both the "urn" scheme and the "device" label are
// matched case-insensitively. It returns everything after "device:", which may
// be empty, and true. It returns "", false if no URI SAN matches.
func extractGuidFromSAN(cert *x509.Certificate) (string, bool) {
	for _, uri := range cert.URIs {
		if !strings.EqualFold(uri.Scheme, "urn") {
			continue
		}
		// A URN has no "//", so url.Parse leaves "device:<GUID>" in Opaque.
		colon := strings.Index(uri.Opaque, ":")
		if colon < 0 {
			continue
		}
		if strings.EqualFold(uri.Opaque[:colon], "device") {
			return uri.Opaque[colon+1:], true
		}
	}
	return "", false
}
// normalizeSerialHex renders the serial-number bytes b as uppercase
// hexadecimal, two digits per byte. No leading zeros are added or stripped, so
// the output mirrors b exactly. An empty slice yields "".
func normalizeSerialHex(b []byte) string {
	return strings.ToUpper(hex.EncodeToString(b))
}
// isSerialRevoked reports whether any row of the models.RevokedSerial table
// has serial_hex equal to serialHex. The comparison is whatever the database's
// "=" does for that column. If the count query fails, it returns false and the
// database error unchanged.
func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) {
	var matches int64
	err := db.Model(&models.RevokedSerial{}).
		Where("serial_hex = ?", serialHex).
		Count(&matches).
		Error
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// MTLSGuard returns middleware for routes that carry the device GUID in the
// :guid path parameter (/tasks/:guid, /renew/:guid).
//
// A request passes only if all of the following hold:
//   - the proxy reports a verified client certificate;
//   - the forwarded certificate names a device GUID in a urn:device:<GUID> SAN;
//   - that GUID equals the path GUID, ignoring case;
//   - the certificate serial is not in the revocation table.
//
// On success the GUID and the uppercase-hex serial are stored in the gin
// context under "mtlsDeviceGUID" and "mtlsSerialHex".
//
// It requires the TLS-terminating proxy (Nginx) to forward:
//
//	X-SSL-Client-Verify, X-SSL-Client-Cert
//
// X-SSL-Client-Serial is not read; the serial is taken from the parsed
// certificate.
//
// NOTE(review): these headers are trusted as-is. That is only safe if the
// proxy always overwrites them and this service cannot be reached without
// going through the proxy. Confirm in the deployment configuration.
func MTLSGuard(db *gorm.DB) gin.HandlerFunc {
	pathGUID := func(c *gin.Context) string {
		return strings.TrimSpace(c.Param("guid"))
	}
	return func(c *gin.Context) {
		if authorizeDeviceCert(c, db, pathGUID) {
			c.Next()
		}
	}
}

// MTLSGuardUpload is the MTLSGuard variant for /records/upload. There the
// device GUID is sent as the "guid" form field, or, if that is empty, the
// "guid" query parameter, instead of a path parameter. Headers, checks,
// responses and context keys are the same as for MTLSGuard.
func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc {
	formGUID := func(c *gin.Context) string {
		if g := strings.TrimSpace(c.PostForm("guid")); g != "" {
			return g
		}
		return strings.TrimSpace(c.Query("guid"))
	}
	return func(c *gin.Context) {
		if authorizeDeviceCert(c, db, formGUID) {
			c.Next()
		}
	}
}

// authorizeDeviceCert runs the checks shared by MTLSGuard and MTLSGuardUpload
// in order. It aborts the request at the first failure:
//   - X-SSL-Client-Verify is not "SUCCESS": 401 with no body.
//   - X-SSL-Client-Cert is missing or unparsable: 401 "invalid client cert".
//   - There is no non-empty urn:device GUID in the SAN: 403 "guid not present in SAN".
//   - The claimed GUID is empty or differs (ignoring case): 403 "guid mismatch".
//   - The revocation lookup fails: 500 "revocation check failed".
//   - The serial is revoked: 403 with no body.
//
// claimedGUID is only called after a certificate carrying a GUID has been
// parsed. Request-form parsing (c.PostForm in MTLSGuardUpload) therefore never
// happens for requests that fail the earlier checks.
//
// When every check passes it stores "mtlsDeviceGUID" and "mtlsSerialHex" in
// the context and returns true. It returns false once the request has been
// aborted.
func authorizeDeviceCert(c *gin.Context, db *gorm.DB, claimedGUID func(*gin.Context) string) bool {
	if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
		c.AbortWithStatus(401)
		return false
	}

	cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
	if err != nil {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
		return false
	}

	certGUID, ok := extractGuidFromSAN(cert)
	if !ok || certGUID == "" {
		c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
		return false
	}

	claimed := claimedGUID(c)
	if claimed == "" || !strings.EqualFold(claimed, certGUID) {
		c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
		return false
	}

	serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
	revoked, err := isSerialRevoked(db, serialHex)
	if err != nil {
		c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
		return false
	}
	if revoked {
		c.AbortWithStatus(403)
		return false
	}

	// Stash the authenticated identity for downstream handlers.
	c.Set("mtlsDeviceGUID", certGUID)
	c.Set("mtlsSerialHex", serialHex)
	return true
}