194 lines
4.5 KiB
Go
194 lines
4.5 KiB
Go
// Path: Backend/internal/middleware/auth.go
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"boostai-backend/internal/config"
|
|
"boostai-backend/internal/sqlc"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// DefaultTokenTTL is the token lifetime CreateToken falls back to when it is
// given a non-positive ttl: one week.
const DefaultTokenTTL = time.Hour * 24 * 7
|
|
|
|
type AuthClaims struct {
|
|
UserID int64 `json:"user_id"`
|
|
Role sqlc.UserRole `json:"role"`
|
|
Email string `json:"email"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
type AuthMiddleware struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
func NewAuthMiddleware(cfg *config.Config) *AuthMiddleware {
|
|
return &AuthMiddleware{cfg: cfg}
|
|
}
|
|
|
|
func (m *AuthMiddleware) CreateToken(userID int64, role sqlc.UserRole, email string, ttl time.Duration) (string, error) {
|
|
if ttl <= 0 {
|
|
ttl = DefaultTokenTTL
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
claims := AuthClaims{
|
|
UserID: userID,
|
|
Role: role,
|
|
Email: email,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: email,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString([]byte(m.cfg.JWTSecret))
|
|
}
|
|
|
|
func (m *AuthMiddleware) RequireAuth() fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
claims, err := m.parseClaims(c)
|
|
if err != nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "unauthorized",
|
|
"message": "Authentication required",
|
|
})
|
|
}
|
|
|
|
c.Locals("auth.user_id", claims.UserID)
|
|
c.Locals("auth.role", claims.Role)
|
|
c.Locals("auth.email", claims.Email)
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (m *AuthMiddleware) RequireTeacher() fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
if CurrentUserRole(c) != sqlc.UserRoleTeacher {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": "forbidden",
|
|
"message": "Teacher access required",
|
|
})
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (m *AuthMiddleware) RequireTeacherSelf(param string) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
if CurrentUserRole(c) != sqlc.UserRoleTeacher {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": "forbidden",
|
|
"message": "Teacher access required",
|
|
})
|
|
}
|
|
|
|
paramID, err := parsePositiveParam(c, param)
|
|
if err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
|
"error": "invalid_request",
|
|
"message": "Invalid path parameter: " + param,
|
|
})
|
|
}
|
|
|
|
if CurrentUserID(c) != paramID {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": "forbidden",
|
|
"message": "You can only access your own teacher resources",
|
|
})
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (m *AuthMiddleware) RequireStudentSelfOrTeacher(param string) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
paramID, err := parsePositiveParam(c, param)
|
|
if err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
|
"error": "invalid_request",
|
|
"message": "Invalid path parameter: " + param,
|
|
})
|
|
}
|
|
|
|
if CurrentUserRole(c) == sqlc.UserRoleTeacher || CurrentUserID(c) == paramID {
|
|
return c.Next()
|
|
}
|
|
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": "forbidden",
|
|
"message": "You can only access your own student resources",
|
|
})
|
|
}
|
|
}
|
|
|
|
func CurrentUserID(c *fiber.Ctx) int64 {
|
|
value, ok := c.Locals("auth.user_id").(int64)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
|
|
return value
|
|
}
|
|
|
|
func CurrentUserRole(c *fiber.Ctx) sqlc.UserRole {
|
|
value, ok := c.Locals("auth.role").(sqlc.UserRole)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
return value
|
|
}
|
|
|
|
func (m *AuthMiddleware) parseClaims(c *fiber.Ctx) (*AuthClaims, error) {
|
|
tokenValue := strings.TrimSpace(c.Cookies(m.cfg.SessionCookie))
|
|
if tokenValue == "" {
|
|
authorization := strings.TrimSpace(c.Get("Authorization"))
|
|
if strings.HasPrefix(strings.ToLower(authorization), "bearer ") {
|
|
tokenValue = strings.TrimSpace(authorization[7:])
|
|
}
|
|
}
|
|
|
|
if tokenValue == "" {
|
|
return nil, errors.New("missing token")
|
|
}
|
|
|
|
parsed, err := jwt.ParseWithClaims(tokenValue, &AuthClaims{}, func(token *jwt.Token) (any, error) {
|
|
if token.Method != jwt.SigningMethodHS256 {
|
|
return nil, errors.New("unexpected signing method")
|
|
}
|
|
|
|
return []byte(m.cfg.JWTSecret), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := parsed.Claims.(*AuthClaims)
|
|
if !ok || !parsed.Valid {
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
func parsePositiveParam(c *fiber.Ctx, param string) (int64, error) {
|
|
value := strings.TrimSpace(c.Params(param))
|
|
parsed, err := strconv.ParseInt(value, 10, 64)
|
|
if err != nil || parsed <= 0 {
|
|
return 0, errors.New("invalid param")
|
|
}
|
|
|
|
return parsed, nil
|
|
}
|