Files
BoostAI/Backend/internal/middleware/auth.go
2026-05-25 17:05:06 +01:00

194 lines
4.5 KiB
Go

// Path: Backend/internal/middleware/auth.go
package middleware
import (
"errors"
"strconv"
"strings"
"time"
"boostai-backend/internal/config"
"boostai-backend/internal/sqlc"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v5"
)
// DefaultTokenTTL is the lifetime CreateToken gives a token when the caller
// passes a non-positive ttl (168 hours, i.e. seven days).
const DefaultTokenTTL time.Duration = 168 * time.Hour
// AuthClaims is the JWT payload signed by CreateToken and decoded by
// parseClaims. The application-specific fields are serialized next to the
// standard registered claims (sub, iat, exp, ...) supplied by the embedded
// jwt.RegisteredClaims.
type AuthClaims struct {
// UserID is stored by RequireAuth under the "auth.user_id" local.
UserID int64 `json:"user_id"`
// Role is stored under "auth.role" and checked by the Require* handlers.
Role sqlc.UserRole `json:"role"`
// Email is stored under "auth.email"; CreateToken also uses it as "sub".
Email string `json:"email"`
jwt.RegisteredClaims
}
// AuthMiddleware issues and verifies session JWTs and provides Fiber handlers
// that enforce authentication and role/ownership rules.
type AuthMiddleware struct {
// cfg supplies the HMAC signing secret (JWTSecret) and the name of the
// session cookie (SessionCookie) that parseClaims reads the token from.
cfg *config.Config
}
// NewAuthMiddleware builds an AuthMiddleware that signs and verifies tokens
// using the secret and cookie name held in cfg.
func NewAuthMiddleware(cfg *config.Config) *AuthMiddleware {
	m := new(AuthMiddleware)
	m.cfg = cfg
	return m
}
// CreateToken signs an HS256 JWT carrying the user's ID, role, and email.
//
// A non-positive ttl falls back to DefaultTokenTTL. The registered "sub" claim
// is set to email, and "iat"/"exp" are derived from the current UTC time.
// It returns an error when no signing secret is configured. An empty HMAC key
// provides no authentication, so this check applies even if the JWT library
// itself accepts an empty key.
func (m *AuthMiddleware) CreateToken(userID int64, role sqlc.UserRole, email string, ttl time.Duration) (string, error) {
	if len(m.cfg.JWTSecret) == 0 {
		return "", errors.New("jwt secret is not configured")
	}
	if ttl <= 0 {
		ttl = DefaultTokenTTL
	}
	now := time.Now().UTC()
	claims := AuthClaims{
		UserID: userID,
		Role:   role,
		Email:  email,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   email,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(m.cfg.JWTSecret))
}
// RequireAuth returns a handler that answers 401 when the request carries no
// valid session token. On success it records the token's user ID, role, and
// email in the request locals ("auth.user_id", "auth.role", "auth.email"),
// where CurrentUserID and CurrentUserRole read them, and continues the chain.
func (m *AuthMiddleware) RequireAuth() fiber.Handler {
	return func(c *fiber.Ctx) error {
		claims, parseErr := m.parseClaims(c)
		if parseErr == nil {
			c.Locals("auth.user_id", claims.UserID)
			c.Locals("auth.role", claims.Role)
			c.Locals("auth.email", claims.Email)
			return c.Next()
		}
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
			"error":   "unauthorized",
			"message": "Authentication required",
		})
	}
}
// RequireTeacher returns a handler that continues only when the caller's role
// is sqlc.UserRoleTeacher and otherwise answers 403. It reads the role that
// RequireAuth stored. If no role was stored, CurrentUserRole yields "" and the
// request is rejected.
func (m *AuthMiddleware) RequireTeacher() fiber.Handler {
	return func(c *fiber.Ctx) error {
		if CurrentUserRole(c) == sqlc.UserRoleTeacher {
			return c.Next()
		}
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
			"error":   "forbidden",
			"message": "Teacher access required",
		})
	}
}
// RequireTeacherSelf returns a handler that admits a teacher only when the
// path parameter named param is a positive integer equal to their own user ID.
//
// Responses:
//   - 403 if the caller is not a teacher.
//   - 400 if the parameter is not a positive integer.
//   - 403 if the parameter names a different user.
//
// It relies on the locals RequireAuth sets.
func (m *AuthMiddleware) RequireTeacherSelf(param string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		if role := CurrentUserRole(c); role != sqlc.UserRoleTeacher {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":   "forbidden",
				"message": "Teacher access required",
			})
		}
		targetID, parseErr := parsePositiveParam(c, param)
		switch {
		case parseErr != nil:
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
				"error":   "invalid_request",
				"message": "Invalid path parameter: " + param,
			})
		case targetID != CurrentUserID(c):
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":   "forbidden",
				"message": "You can only access your own teacher resources",
			})
		default:
			return c.Next()
		}
	}
}
// RequireStudentSelfOrTeacher returns a handler that validates the path
// parameter named param as a positive integer (400 otherwise). It then admits
// the request if the caller is a teacher or if the parameter equals the
// caller's own user ID, and answers 403 in every other case.
func (m *AuthMiddleware) RequireStudentSelfOrTeacher(param string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		targetID, parseErr := parsePositiveParam(c, param)
		if parseErr != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
				"error":   "invalid_request",
				"message": "Invalid path parameter: " + param,
			})
		}
		isTeacher := CurrentUserRole(c) == sqlc.UserRoleTeacher
		if !isTeacher && CurrentUserID(c) != targetID {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":   "forbidden",
				"message": "You can only access your own student resources",
			})
		}
		return c.Next()
	}
}
// CurrentUserID returns the user ID that RequireAuth stored in the
// "auth.user_id" local, or 0 when that local is missing or not an int64.
func CurrentUserID(c *fiber.Ctx) int64 {
	if id, ok := c.Locals("auth.user_id").(int64); ok {
		return id
	}
	return 0
}
// CurrentUserRole returns the role that RequireAuth stored in the "auth.role"
// local, or the empty role when that local is missing or of another type.
func CurrentUserRole(c *fiber.Ctx) sqlc.UserRole {
	if role, ok := c.Locals("auth.role").(sqlc.UserRole); ok {
		return role
	}
	return ""
}
// parseClaims extracts and verifies the session token for the current request.
//
// The token is read from the configured session cookie first. Only when that
// cookie is absent or blank does it fall back to an
// "Authorization: Bearer <token>" header, with the scheme matched
// case-insensitively.
//
// A token is accepted only if it:
//   - is signed with HS256 using the configured secret;
//   - carries an "exp" claim. golang-jwt v5 treats "exp" as optional by
//     default, so without this option a validly-signed token lacking it would
//     never expire;
//   - passes the library's standard time-based claim validation.
//
// It returns an error for a missing secret, a missing token, or a token that
// fails any of these checks.
func (m *AuthMiddleware) parseClaims(c *fiber.Ctx) (*AuthClaims, error) {
	if len(m.cfg.JWTSecret) == 0 {
		// An empty HMAC key provides no authentication; refuse to verify with it.
		return nil, errors.New("jwt secret is not configured")
	}
	tokenValue := strings.TrimSpace(c.Cookies(m.cfg.SessionCookie))
	if tokenValue == "" {
		authorization := strings.TrimSpace(c.Get("Authorization"))
		scheme, credentials, found := strings.Cut(authorization, " ")
		if found && strings.EqualFold(scheme, "bearer") {
			tokenValue = strings.TrimSpace(credentials)
		}
	}
	if tokenValue == "" {
		return nil, errors.New("missing token")
	}
	parsed, err := jwt.ParseWithClaims(
		tokenValue,
		&AuthClaims{},
		func(token *jwt.Token) (any, error) {
			// Redundant with WithValidMethods below; kept as defense in depth
			// against algorithm-confusion tokens.
			if token.Method != jwt.SigningMethodHS256 {
				return nil, errors.New("unexpected signing method")
			}
			return []byte(m.cfg.JWTSecret), nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}
	claims, ok := parsed.Claims.(*AuthClaims)
	if !ok || !parsed.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}
// parsePositiveParam reads the route parameter named param, trims surrounding
// whitespace, and parses it as a base-10 int64. It returns an error unless the
// result is strictly greater than zero.
func parsePositiveParam(c *fiber.Ctx, param string) (int64, error) {
	raw := strings.TrimSpace(c.Params(param))
	if id, convErr := strconv.ParseInt(raw, 10, 64); convErr == nil && id > 0 {
		return id, nil
	}
	return 0, errors.New("invalid param")
}