main.go
package main

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/redis/go-redis/v9"
)

// JWTManger issues, validates, and revokes JWTs signed with a shared HMAC
// secret. Revocations are recorded in Redis under keys that start with
// blacklistKey.
type JWTManger struct {
	// ctx is passed to every Redis command the manager issues.
	ctx context.Context
	// redisClient is the Redis connection holding the revocation blacklist.
	redisClient *redis.Client
	// secretKey is the HMAC key used to sign and verify tokens.
	secretKey []byte
	// blacklistKey is the prefix of every blacklist entry's Redis key.
	blacklistKey string
	// defaultTTL is the lifetime given to newly generated tokens.
	defaultTTL time.Duration
}

// UserClaims is the claim set carried by tokens from JWTManger: the user's
// ID and email (encoded as "user_id" and "email") plus the embedded standard
// registered claims (exp, iat, nbf, iss, ...), whose fields are flattened
// into the same JSON object.
//
// encoding/json emits keys in struct field order, so reordering these fields
// changes the encoded payload (and therefore the token string).
type UserClaims struct {
	UserID string `json:"user_id"`
	Email  string `json:"email"`
	jwt.RegisteredClaims
}

// NewJWTManger returns a manager that signs tokens with secretKey, gives new
// tokens a lifetime of defaultTTL, and stores revocations in redisClient under
// the "jwt:blacklist" key prefix. All Redis calls use context.Background().
func NewJWTManger(secretKey string, redisClient *redis.Client, defaultTTL time.Duration) *JWTManger {
	mgr := new(JWTManger)
	mgr.secretKey = []byte(secretKey)
	mgr.redisClient = redisClient
	mgr.blacklistKey = "jwt:blacklist"
	mgr.defaultTTL = defaultTTL
	mgr.ctx = context.Background()
	return mgr
}

// GenerateToken issues an HS256-signed token for the given user.
//
// The token carries userID and email as custom claims, plus:
//   - iat and nbf set to the issue time,
//   - exp set to the issue time + defaultTTL,
//   - iss set to "go-jwt-manager",
//   - a random 128-bit jti.
//
// The jti makes every token unique. Registered time claims have one-second
// resolution, so without it two tokens issued to the same user within the
// same second would be byte-identical. Revoking one (the blacklist is keyed
// by the token string) would then revoke the other session as well.
//
// It returns an error if random ID generation or signing fails.
func (m *JWTManger) GenerateToken(userID, email string) (string, error) {
	// Use one timestamp so iat, nbf and exp are mutually consistent.
	now := time.Now()

	var id [16]byte
	if _, err := rand.Read(id[:]); err != nil {
		return "", fmt.Errorf("failed to generate token ID: %w", err)
	}

	claims := &UserClaims{
		UserID: userID,
		Email:  email,
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        fmt.Sprintf("%x", id[:]),
			ExpiresAt: jwt.NewNumericDate(now.Add(m.defaultTTL)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    "go-jwt-manager",
		},
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(m.secretKey)
	if err != nil {
		return "", fmt.Errorf("failed to sign token: %w", err)
	}
	return tokenString, nil
}

// ValidateToken verifies tokenString and returns its claims.
//
// The token is rejected if:
//   - its signature does not verify against the manager's secret,
//   - it uses a non-HMAC algorithm,
//   - the jwt library's time-claim checks fail,
//   - it is expired, or
//   - it appears in the Redis revocation blacklist.
//
// A Redis failure is returned as an error rather than treated as "not
// revoked", so the method fails closed.
func (m *JWTManger) ValidateToken(tokenString string) (*UserClaims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (any, error) {
		// Only HMAC algorithms make sense for a shared-secret key; refuse
		// anything else the token header claims to use.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("token parsing failed: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}

	claims, ok := token.Claims.(*UserClaims)
	if !ok {
		return nil, fmt.Errorf("failed to extract claims")
	}

	// Defensive expiry check: the parser already validates exp when present.
	// ExpiresAt is nil for a token without an exp claim, and calling Before
	// on it unguarded would panic. This check runs before the Redis lookup
	// so an expired token costs no round trip.
	if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
		return nil, fmt.Errorf("token has expired")
	}

	isRevoked, err := m.IsTokenRevoked(tokenString)
	if err != nil {
		return nil, fmt.Errorf("failed to check revocation status: %w", err)
	}
	if isRevoked {
		return nil, fmt.Errorf("token has been revoked")
	}

	return claims, nil
}

// RevokeToken adds tokenString to the Redis blacklist until the token expires.
//
// The signature is verified first. Time-based claims are deliberately not
// validated, so a token can be revoked regardless of nbf/exp. Skipping the
// signature check would let anyone submit forged tokens with far-future exp
// values and fill Redis with long-lived blacklist entries.
//
// Expiry handling:
//   - A token that has already expired is not stored; the jwt parser rejects
//     expired tokens anyway.
//   - A token without an exp claim never expires, so its entry is stored
//     with no TTL. A finite TTL would let the revoked token become valid
//     again once the entry lapsed.
//
// It returns an error if the token cannot be verified or Redis rejects the
// write.
func (m *JWTManger) RevokeToken(tokenString string) error {
	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (any, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.secretKey, nil
	}, jwt.WithoutClaimsValidation())
	if err != nil {
		return fmt.Errorf("failed to parse token: %w", err)
	}

	claims, ok := token.Claims.(*UserClaims)
	if !ok {
		return fmt.Errorf("failed to extract claims")
	}

	// A zero TTL tells go-redis to store the key with no expiration, which
	// is what we want for tokens that have no exp claim.
	var ttl time.Duration
	if claims.ExpiresAt != nil {
		ttl = time.Until(claims.ExpiresAt.Time)
		// Use <= rather than <: a TTL of exactly zero would be read as
		// "no expiration" and the entry would never be removed.
		if ttl <= 0 {
			return nil
		}
	}

	if err := m.redisClient.Set(m.ctx, m.getBlacklistKey(tokenString), "revoked", ttl).Err(); err != nil {
		return fmt.Errorf("failed to add to blacklist: %w", err)
	}
	return nil
}

// IsTokenRevoked reports whether tokenString has a blacklist entry in Redis
// whose value is "revoked".
//
// A missing key means "not revoked". Any other Redis failure is returned as
// an error so callers can fail closed instead of accepting the token.
func (m *JWTManger) IsTokenRevoked(tokenString string) (bool, error) {
	val, err := m.redisClient.Get(m.ctx, m.getBlacklistKey(tokenString)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// errors.Is still matches if a hook or wrapper has wrapped redis.Nil.
		return false, nil
	case err != nil:
		return false, fmt.Errorf("redis error: %w", err)
	}
	return val == "revoked", nil
}

// getBlacklistKey returns the Redis key for tokenString's blacklist entry.
// The key is the manager's prefix, a colon, and then the raw token.
func (m *JWTManger) getBlacklistKey(tokenString string) string {
	return m.blacklistKey + ":" + tokenString
}

// main demonstrates the issue → validate → revoke → re-validate flow against
// a Redis server at localhost:6379.
func main() {
	redisClient := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
	})
	// Release the client's connection pool on every exit path.
	defer func() {
		if err := redisClient.Close(); err != nil {
			fmt.Printf("Failed to close Redis client: %v\n", err)
		}
	}()

	if _, err := redisClient.Ping(context.Background()).Result(); err != nil {
		fmt.Printf("Redis connection failed: %v\n", err)
		return
	}

	// Demo-only secret: load it from configuration or a secret store in
	// real deployments.
	manager := NewJWTManger("your-strong-secret-key-keep-it-safe", redisClient, time.Hour)

	token, err := manager.GenerateToken("user_123", "user@example.com")
	if err != nil {
		fmt.Printf("Failed to generate token: %v\n", err)
		return
	}
	fmt.Printf("Generated Token: %s\n\n", token)

	// A freshly issued token should validate.
	claims, err := manager.ValidateToken(token)
	if err != nil {
		fmt.Printf("Token validation failed: %v\n", err)
		return
	}
	fmt.Printf("Valid Token - UserID: %s, Email: %s, Expires: %v\n\n", claims.UserID, claims.Email, claims.ExpiresAt.Time)

	fmt.Println("Revoking token...")
	if err := manager.RevokeToken(token); err != nil {
		fmt.Printf("Failed to revoke token: %v\n", err)
		return
	}

	// After revocation, validation must fail. Report it loudly if it doesn't.
	if _, err := manager.ValidateToken(token); err != nil {
		fmt.Printf("Token validation failed as expected: %v\n", err)
	} else {
		fmt.Println("Unexpected: revoked token still validated")
	}
}

How It Works

Adds JWT revocation by recording revoked tokens in a Redis blacklist and checking that blacklist every time a token is validated.

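GenerateToken issues an HS256-signed token containing the user's ID and email plus the standard iat/nbf/exp/iss claims. RevokeToken writes the token to a Redis blacklist with an expiration matching the token's remaining lifetime. ValidateToken checks the signature and expiry, then queries Redis to reject revoked tokens.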

Key Concepts

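  • The Redis blacklist entry's TTL mirrors the token's remaining lifetime, so the blacklist does not grow without bound.
  • Validation fails immediately when a token's blacklist entry is present; HTTP middleware built on ValidateToken would typically answer such requests with 401.
  • Keeping issuance and revocation in separate methods keeps the logic clear.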

When to Use This Pattern

  • Logging out sessions immediately instead of waiting for expiry.
  • Invalidating compromised tokens after user password resets.
  • Revoking tokens consistently across distributed service instances that share a single Redis blacklist.

Best Practices

  • Use unique jti claims and store them atomically with expiry.
  • Handle Redis outages defensively and fail closed for sensitive endpoints.
  • Rotate signing keys carefully to avoid invalidating all users unnecessarily.
Go Version: 1.19
Difficulty: advanced
Production Ready: Yes
Lines of Code: 211