package cc.mrbird.febs.gateway.enhance.service.impl; import cc.mrbird.febs.common.core.entity.FebsResponse; import cc.mrbird.febs.common.core.utils.DateUtil; import cc.mrbird.febs.common.core.utils.FebsUtil; import cc.mrbird.febs.gateway.enhance.entity.*; import cc.mrbird.febs.gateway.enhance.service.*; import cc.mrbird.febs.gateway.enhance.utils.AddressUtil; import com.alibaba.fastjson.JSONObject; import com.google.common.base.Stopwatch; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.cloud.gateway.route.Route; import org.springframework.cloud.gateway.support.ServerWebExchangeUtils; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.stereotype.Service; import org.springframework.util.AntPathMatcher; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.net.URI; import java.time.LocalTime; import java.util.LinkedHashSet; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; /** * @author MrBird */ @Slf4j @Service @RequiredArgsConstructor public class RouteEnhanceServiceImpl implements RouteEnhanceService { private static final String METHOD_ALL = "ALL"; private static final String TOKEN_CHECK_URL = "/auth/user"; private final RouteLogService routeLogService; private final BlockLogService blockLogService; private final RateLimitLogService rateLimitLogService; private final RouteEnhanceCacheService routeEnhanceCacheService; private final AntPathMatcher pathMatcher = new AntPathMatcher(); @Override public Mono filterBlackList(ServerWebExchange exchange) { Stopwatch stopwatch = Stopwatch.createStarted(); ServerHttpRequest request = exchange.getRequest(); ServerHttpResponse response = exchange.getResponse(); try { URI originUri = getGatewayOriginalRequestUrl(exchange); if (originUri != null) { String requestIp = FebsUtil.getServerHttpRequestIpAddress(request); String requestMethod = request.getMethodValue(); AtomicBoolean forbid = new AtomicBoolean(false); Set blackList = routeEnhanceCacheService.getBlackList(requestIp); blackList.addAll(routeEnhanceCacheService.getBlackList()); doBlackListCheck(forbid, blackList, originUri, requestMethod); log.info("Blacklist verification completed - {}", stopwatch.stop()); if (forbid.get()) { return FebsUtil.makeWebFluxResponse(response, MediaType.APPLICATION_JSON_VALUE, HttpStatus.NOT_ACCEPTABLE, new FebsResponse().message("黑名单限制,禁止访问")); } } else { log.info("Request IP not obtained, no blacklist check - {}", stopwatch.stop()); } } catch (Exception e) { log.warn("Blacklist verification failed : {} - {}", e.getMessage(), stopwatch.stop()); } return null; } @Override public Mono filterRateLimit(ServerWebExchange exchange) { Stopwatch stopwatch = Stopwatch.createStarted(); ServerHttpRequest request = exchange.getRequest(); ServerHttpResponse response = exchange.getResponse(); try { URI originUri = getGatewayOriginalRequestUrl(exchange); if (originUri != null) { String requestIp = FebsUtil.getServerHttpRequestIpAddress(request); String requestMethod = request.getMethodValue(); AtomicBoolean limit = new AtomicBoolean(false); Object o = routeEnhanceCacheService.getRateLimitRule(originUri.getPath(), METHOD_ALL); if (o == null) { o = routeEnhanceCacheService.getRateLimitRule(originUri.getPath(), requestMethod); } if (o != null) { RateLimitRule rule = JSONObject.parseObject(o.toString(), RateLimitRule.class); Mono result = doRateLimitCheck(limit, rule, originUri, requestIp, requestMethod, response); log.info("Rate limit verification completed - {}", stopwatch.stop()); if (result != null) { return result; } } } else { log.info("Request IP not obtained, no rate limit filter - {}", stopwatch.stop()); } } catch (Exception e) { log.warn("Current limit failure : {} - {}", e.getMessage(), stopwatch.stop()); } return null; } @Override public void saveRequestLogs(ServerWebExchange exchange) { URI originUri = getGatewayOriginalRequestUrl(exchange); // /auth/user为令牌校验请求,是系统自发行为,非用户请求,故不记录 if (!StringUtils.equalsIgnoreCase(TOKEN_CHECK_URL, originUri.getPath())) { URI url = getGatewayRequestUrl(exchange); Route route = getGatewayRoute(exchange); ServerHttpRequest request = exchange.getRequest(); String ipAddress = FebsUtil.getServerHttpRequestIpAddress(request); if (url != null && route != null) { RouteLog routeLog = RouteLog.builder() .ip(ipAddress) .requestUri(originUri.getPath()) .targetServer(route.getId()) .targetUri(url.getPath()) .requestMethod(request.getMethodValue()) .location(AddressUtil.getCityInfo(ipAddress)) .build(); routeLogService.create(routeLog).subscribe(); } } } @Override public void saveBlockLogs(ServerWebExchange exchange) { URI originUri = getGatewayOriginalRequestUrl(exchange); ServerHttpRequest request = exchange.getRequest(); String requestIp = FebsUtil.getServerHttpRequestIpAddress(request); if (originUri != null) { BlockLog blockLog = BlockLog.builder() .ip(requestIp) .requestMethod(request.getMethodValue()) .requestUri(originUri.getPath()) .build(); blockLogService.create(blockLog).subscribe(); log.info("Store blocked request logs >>>"); } } @Override public void saveRateLimitLogs(ServerWebExchange exchange) { URI originUri = getGatewayOriginalRequestUrl(exchange); ServerHttpRequest request = exchange.getRequest(); String requestIp = FebsUtil.getServerHttpRequestIpAddress(request); if (originUri != null) { RateLimitLog rateLimitLog = RateLimitLog.builder() .ip(requestIp) .requestMethod(request.getMethodValue()) .requestUri(originUri.getPath()) .build(); rateLimitLogService.create(rateLimitLog).subscribe(); log.info("Store rate limit logs >>>"); } } private void doBlackListCheck(AtomicBoolean forbid, Set blackList, URI uri, String requestMethod) { for (Object o : blackList) { BlackList b = JSONObject.parseObject(o.toString(), BlackList.class); if (pathMatcher.match(b.getRequestUri(), uri.getPath()) && BlackList.OPEN == Integer.parseInt(b.getStatus())) { if (BlackList.METHOD_ALL.equalsIgnoreCase(b.getRequestMethod()) || StringUtils.equalsIgnoreCase(requestMethod, b.getRequestMethod())) { if (StringUtils.isNotBlank(b.getLimitFrom()) && StringUtils.isNotBlank(b.getLimitTo())) { if (DateUtil.between(LocalTime.parse(b.getLimitFrom()), LocalTime.parse(b.getLimitTo()))) { forbid.set(true); } } else { forbid.set(true); } } } if (forbid.get()) { break; } } } private Mono doRateLimitCheck(AtomicBoolean limit, RateLimitRule rule, URI uri, String requestIp, String requestMethod, ServerHttpResponse response) { boolean isRateLimitRuleHit = RateLimitRule.OPEN == Integer.parseInt(rule.getStatus()) &&(RateLimitRule.METHOD_ALL.equalsIgnoreCase(rule.getRequestMethod()) || StringUtils.equalsIgnoreCase(requestMethod, rule.getRequestMethod())); if (isRateLimitRuleHit) { if (StringUtils.isNotBlank(rule.getLimitFrom()) && StringUtils.isNotBlank(rule.getLimitTo())) { if (DateUtil.between(LocalTime.parse(rule.getLimitFrom()), LocalTime.parse(rule.getLimitTo()))) { limit.set(true); } } else { limit.set(true); } } if (limit.get()) { String requestUri = uri.getPath(); int count = routeEnhanceCacheService.getCurrentRequestCount(requestUri, requestIp); if (count == 0) { routeEnhanceCacheService.setCurrentRequestCount(requestUri, requestIp, Long.parseLong(rule.getIntervalSec())); } else if (count >= Integer.parseInt(rule.getCount())) { return FebsUtil.makeWebFluxResponse(response, MediaType.APPLICATION_JSON_VALUE, HttpStatus.TOO_MANY_REQUESTS, new FebsResponse().message("访问频率超限,请稍后再试")); } else { routeEnhanceCacheService.incrCurrentRequestCount(requestUri, requestIp); } } return null; } private URI getGatewayOriginalRequestUrl(ServerWebExchange exchange) { LinkedHashSet uris = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR); URI originUri = null; if (uris != null) { originUri = uris.stream().findFirst().orElse(null); } return originUri; } private URI getGatewayRequestUrl(ServerWebExchange exchange) { return exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR); } private Route getGatewayRoute(ServerWebExchange exchange) { return exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR); } }