diff --git a/pom.xml b/pom.xml index 47ef3fbebd89aae778270c21e474f172221e4a6e..97d9de582ee373db419de143cbd53fb36cdea3eb 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ org.yeauty netty-websocket-spring-boot-starter - 0.12.0 + 0.13.0 netty-websocket-spring-boot-starter @@ -36,8 +36,8 @@ - 4.1.67.Final - 2.0.0.RELEASE + 4.1.116.Final + 2.3.12.RELEASE @@ -58,6 +58,11 @@ netty-handler ${netty.version} + + javax.servlet + javax.servlet-api + 4.0.1 + diff --git a/src/main/java/org/yeauty/pojo/PojoEndpointServer.java b/src/main/java/org/yeauty/pojo/PojoEndpointServer.java index f1fd925971396abdae084950fcc8476f409e18c9..92fc6a8c4e3ffea71df9e8367c238851962b551a 100755 --- a/src/main/java/org/yeauty/pojo/PojoEndpointServer.java +++ b/src/main/java/org/yeauty/pojo/PojoEndpointServer.java @@ -1,257 +1,271 @@ -package org.yeauty.pojo; - -import io.netty.channel.Channel; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; -import io.netty.util.internal.logging.InternalLogger; -import io.netty.util.internal.logging.InternalLoggerFactory; -import org.springframework.beans.TypeMismatchException; -import org.yeauty.standard.ServerEndpointConfig; -import org.yeauty.support.*; - -import java.lang.reflect.Method; -import java.util.*; - -/** - * @author Yeauty - * @version 1.0 - */ -public class PojoEndpointServer { - - private static final AttributeKey POJO_KEY = AttributeKey.valueOf("WEBSOCKET_IMPLEMENT"); - - public static final AttributeKey SESSION_KEY = AttributeKey.valueOf("WEBSOCKET_SESSION"); - - private static final AttributeKey PATH_KEY = AttributeKey.valueOf("WEBSOCKET_PATH"); - - public static final AttributeKey> URI_TEMPLATE = AttributeKey.valueOf("WEBSOCKET_URI_TEMPLATE"); - - public static final AttributeKey>> REQUEST_PARAM = AttributeKey.valueOf("WEBSOCKET_REQUEST_PARAM"); - - private final Map pathMethodMappingMap = new HashMap<>(); - - private final ServerEndpointConfig config; - - private Set pathMatchers = new HashSet<>(); - - private static final InternalLogger logger = InternalLoggerFactory.getInstance(PojoEndpointServer.class); - - public PojoEndpointServer(PojoMethodMapping methodMapping, ServerEndpointConfig config, String path) { - addPathPojoMethodMapping(path, methodMapping); - this.config = config; - } - - public boolean hasBeforeHandshake(Channel channel, String path) { - PojoMethodMapping methodMapping = getPojoMethodMapping(path, channel); - return methodMapping.getBeforeHandshake()!=null; - } - - public void doBeforeHandshake(Channel channel, FullHttpRequest req, String path) { - PojoMethodMapping methodMapping = null; - methodMapping = getPojoMethodMapping(path, channel); - - Object implement = null; - try { - implement = methodMapping.getEndpointInstance(); - } catch (Exception e) { - logger.error(e); - return; - } - channel.attr(POJO_KEY).set(implement); - Session session = new Session(channel); - channel.attr(SESSION_KEY).set(session); - Method beforeHandshake = methodMapping.getBeforeHandshake(); - if (beforeHandshake != null) { - try { - beforeHandshake.invoke(implement, methodMapping.getBeforeHandshakeArgs(channel, req)); - } catch (TypeMismatchException e) { - throw e; - } catch (Throwable t) { - logger.error(t); - } - } - } - - public void doOnOpen(Channel channel, FullHttpRequest req, String path) { - PojoMethodMapping methodMapping = getPojoMethodMapping(path, channel); - - Object implement = channel.attr(POJO_KEY).get(); - if (implement==null){ - try { - implement = methodMapping.getEndpointInstance(); - channel.attr(POJO_KEY).set(implement); - } catch (Exception e) { - logger.error(e); - return; - } - Session session = new Session(channel); - channel.attr(SESSION_KEY).set(session); - } - - Method onOpenMethod = methodMapping.getOnOpen(); - if (onOpenMethod != null) { - try { - onOpenMethod.invoke(implement, methodMapping.getOnOpenArgs(channel, req)); - } catch (TypeMismatchException e) { - throw e; - } catch (Throwable t) { - logger.error(t); - } - } - } - - public void doOnClose(Channel channel) { - Attribute attrPath = channel.attr(PATH_KEY); - PojoMethodMapping methodMapping = null; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - String path = attrPath.get(); - methodMapping = pathMethodMappingMap.get(path); - if (methodMapping == null) { - return; - } - } - if (methodMapping.getOnClose() != null) { - if (!channel.hasAttr(SESSION_KEY)) { - return; - } - Object implement = channel.attr(POJO_KEY).get(); - try { - methodMapping.getOnClose().invoke(implement, - methodMapping.getOnCloseArgs(channel)); - } catch (Throwable t) { - logger.error(t); - } - } - } - - - public void doOnError(Channel channel, Throwable throwable) { - Attribute attrPath = channel.attr(PATH_KEY); - PojoMethodMapping methodMapping = null; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - String path = attrPath.get(); - methodMapping = pathMethodMappingMap.get(path); - } - if (methodMapping.getOnError() != null) { - if (!channel.hasAttr(SESSION_KEY)) { - return; - } - Object implement = channel.attr(POJO_KEY).get(); - try { - Method method = methodMapping.getOnError(); - Object[] args = methodMapping.getOnErrorArgs(channel, throwable); - method.invoke(implement, args); - } catch (Throwable t) { - logger.error(t); - } - } - } - - public void doOnMessage(Channel channel, WebSocketFrame frame) { - Attribute attrPath = channel.attr(PATH_KEY); - PojoMethodMapping methodMapping = null; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - String path = attrPath.get(); - methodMapping = pathMethodMappingMap.get(path); - } - if (methodMapping.getOnMessage() != null) { - TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; - Object implement = channel.attr(POJO_KEY).get(); - try { - methodMapping.getOnMessage().invoke(implement, methodMapping.getOnMessageArgs(channel, textFrame)); - } catch (Throwable t) { - logger.error(t); - } - } - } - - public void doOnBinary(Channel channel, WebSocketFrame frame) { - Attribute attrPath = channel.attr(PATH_KEY); - PojoMethodMapping methodMapping = null; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - String path = attrPath.get(); - methodMapping = pathMethodMappingMap.get(path); - } - if (methodMapping.getOnBinary() != null) { - BinaryWebSocketFrame binaryWebSocketFrame = (BinaryWebSocketFrame) frame; - Object implement = channel.attr(POJO_KEY).get(); - try { - methodMapping.getOnBinary().invoke(implement, methodMapping.getOnBinaryArgs(channel, binaryWebSocketFrame)); - } catch (Throwable t) { - logger.error(t); - } - } - } - - public void doOnEvent(Channel channel, Object evt) { - Attribute attrPath = channel.attr(PATH_KEY); - PojoMethodMapping methodMapping = null; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - String path = attrPath.get(); - methodMapping = pathMethodMappingMap.get(path); - } - if (methodMapping.getOnEvent() != null) { - if (!channel.hasAttr(SESSION_KEY)) { - return; - } - Object implement = channel.attr(POJO_KEY).get(); - try { - methodMapping.getOnEvent().invoke(implement, methodMapping.getOnEventArgs(channel, evt)); - } catch (Throwable t) { - logger.error(t); - } - } - } - - public String getHost() { - return config.getHost(); - } - - public int getPort() { - return config.getPort(); - } - - public Set getPathMatcherSet() { - return pathMatchers; - } - - public void addPathPojoMethodMapping(String path, PojoMethodMapping pojoMethodMapping) { - pathMethodMappingMap.put(path, pojoMethodMapping); - for (MethodArgumentResolver onOpenArgResolver : pojoMethodMapping.getOnOpenArgResolvers()) { - if (onOpenArgResolver instanceof PathVariableMethodArgumentResolver || onOpenArgResolver instanceof PathVariableMapMethodArgumentResolver) { - pathMatchers.add(new AntPathMatcherWrapper(path)); - return; - } - } - pathMatchers.add(new DefaultPathMatcher(path)); - } - - private PojoMethodMapping getPojoMethodMapping(String path, Channel channel) { - PojoMethodMapping methodMapping; - if (pathMethodMappingMap.size() == 1) { - methodMapping = pathMethodMappingMap.values().iterator().next(); - } else { - Attribute attrPath = channel.attr(PATH_KEY); - attrPath.set(path); - methodMapping = pathMethodMappingMap.get(path); - if (methodMapping == null) { - throw new RuntimeException("path " + path + " is not in pathMethodMappingMap "); - } - } - return methodMapping; - } -} +package org.yeauty.pojo; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.springframework.beans.TypeMismatchException; +import org.yeauty.standard.ServerEndpointConfig; +import org.yeauty.support.*; + +import java.lang.reflect.Method; +import java.util.*; + +/** + * @author Yeauty + * @version 1.0 + */ +public class PojoEndpointServer { + + private static final AttributeKey POJO_KEY = AttributeKey.valueOf("WEBSOCKET_IMPLEMENT"); + + public static final AttributeKey SESSION_KEY = AttributeKey.valueOf("WEBSOCKET_SESSION"); + + private static final AttributeKey PATH_KEY = AttributeKey.valueOf("WEBSOCKET_PATH"); + + public static final AttributeKey> URI_TEMPLATE = AttributeKey.valueOf("WEBSOCKET_URI_TEMPLATE"); + + public static final AttributeKey>> REQUEST_PARAM = AttributeKey.valueOf("WEBSOCKET_REQUEST_PARAM"); + + private final Map pathMethodMappingMap = new HashMap<>(); + + private final ServerEndpointConfig config; + + private Set pathMatchers = new HashSet<>(); + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PojoEndpointServer.class); + + public PojoEndpointServer(PojoMethodMapping methodMapping, ServerEndpointConfig config, String path) { + addPathPojoMethodMapping(path, methodMapping); + this.config = config; + } + + public boolean hasBeforeHandshake(Channel channel, String path) { + PojoMethodMapping methodMapping = getPojoMethodMapping(path, channel); + return methodMapping.getBeforeHandshake()!=null; + } + + public void doBeforeHandshake(Channel channel, FullHttpRequest req, String path) { + PojoMethodMapping methodMapping = null; + methodMapping = getPojoMethodMapping(path, channel); + + Object implement = null; + try { + implement = methodMapping.getEndpointInstance(); + } catch (Exception e) { + logger.error(e); + return; + } + channel.attr(POJO_KEY).set(implement); + Session session = new Session(channel); + channel.attr(SESSION_KEY).set(session); + Method beforeHandshake = methodMapping.getBeforeHandshake(); + if (beforeHandshake != null) { + try { + beforeHandshake.invoke(implement, methodMapping.getBeforeHandshakeArgs(channel, req)); + } catch (TypeMismatchException e) { + throw e; + } catch (Throwable t) { + logger.error(t); + } + } + } + + public void doOnOpen(Channel channel, FullHttpRequest req, String path) { + PojoMethodMapping methodMapping = getPojoMethodMapping(path, channel); + + Object implement = channel.attr(POJO_KEY).get(); + if (implement==null){ + try { + implement = methodMapping.getEndpointInstance(); + channel.attr(POJO_KEY).set(implement); + } catch (Exception e) { + logger.error(e); + return; + } + Session session = new Session(channel); + channel.attr(SESSION_KEY).set(session); + } + + Method onOpenMethod = methodMapping.getOnOpen(); + if (onOpenMethod != null) { + try { + onOpenMethod.invoke(implement, methodMapping.getOnOpenArgs(channel, req)); + } catch (TypeMismatchException e) { + throw e; + } catch (Throwable t) { + logger.error(t); + } + } + } + + public void doOnClose(Channel channel) { + Attribute attrPath = channel.attr(PATH_KEY); + PojoMethodMapping methodMapping = null; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + String path = attrPath.get(); + methodMapping = pathMethodMappingMap.get(path); + if (methodMapping == null) { + return; + } + } + if (methodMapping.getOnClose() != null) { + if (!channel.hasAttr(SESSION_KEY)) { + return; + } + Object implement = channel.attr(POJO_KEY).get(); + try { + methodMapping.getOnClose().invoke(implement, + methodMapping.getOnCloseArgs(channel)); + } catch (Throwable t) { + logger.error(t); + } + } + } + + + public void doOnError(Channel channel, Throwable throwable) { + Attribute attrPath = channel.attr(PATH_KEY); + PojoMethodMapping methodMapping = null; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + String path = attrPath.get(); + methodMapping = pathMethodMappingMap.get(path); + } + if (methodMapping.getOnError() != null) { + if (!channel.hasAttr(SESSION_KEY)) { + return; + } + Object implement = channel.attr(POJO_KEY).get(); + try { + Method method = methodMapping.getOnError(); + Object[] args = methodMapping.getOnErrorArgs(channel, throwable); + method.invoke(implement, args); + } catch (Throwable t) { + logger.error(t); + } + } + } + + public void doOnMessage(Channel channel, WebSocketFrame frame) { + Attribute attrPath = channel.attr(PATH_KEY); + PojoMethodMapping methodMapping = null; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + String path = attrPath.get(); + methodMapping = pathMethodMappingMap.get(path); + } + if (methodMapping.getOnMessage() != null) { + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + Object implement = channel.attr(POJO_KEY).get(); + try { + methodMapping.getOnMessage().invoke(implement, methodMapping.getOnMessageArgs(channel, textFrame)); + } catch (Throwable t) { + logger.error(t); + } + } + } + + /** + * OnBinary接口执行 + * @param channel + * @param frame + * @return 返回byteBuf是否自动释放的标记,如果netty接口入参存在ByteBuf类型,那么不关闭,让框架外自行关闭 + */ + public boolean doOnBinary(Channel channel, WebSocketFrame frame) { + Attribute attrPath = channel.attr(PATH_KEY); + PojoMethodMapping methodMapping = null; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + String path = attrPath.get(); + methodMapping = pathMethodMappingMap.get(path); + } + if (methodMapping.getOnBinary() != null) { + BinaryWebSocketFrame binaryWebSocketFrame = (BinaryWebSocketFrame) frame; + Object implement = channel.attr(POJO_KEY).get(); + try { + Object[] args = methodMapping.getOnBinaryArgs(channel, binaryWebSocketFrame); + methodMapping.getOnBinary().invoke(implement, args); + for(Object arg : args){ + if(arg instanceof ByteBuf||arg instanceof Channel){ + return false; + } + } + } catch (Throwable t) { + logger.error(t); + } + } + return true; + } + + public void doOnEvent(Channel channel, Object evt) { + Attribute attrPath = channel.attr(PATH_KEY); + PojoMethodMapping methodMapping = null; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + String path = attrPath.get(); + methodMapping = pathMethodMappingMap.get(path); + } + if (methodMapping.getOnEvent() != null) { + if (!channel.hasAttr(SESSION_KEY)) { + return; + } + Object implement = channel.attr(POJO_KEY).get(); + try { + methodMapping.getOnEvent().invoke(implement, methodMapping.getOnEventArgs(channel, evt)); + } catch (Throwable t) { + logger.error(t); + } + } + } + + public String getHost() { + return config.getHost(); + } + + public int getPort() { + return config.getPort(); + } + + public Set getPathMatcherSet() { + return pathMatchers; + } + + public void addPathPojoMethodMapping(String path, PojoMethodMapping pojoMethodMapping) { + pathMethodMappingMap.put(path, pojoMethodMapping); + for (MethodArgumentResolver onOpenArgResolver : pojoMethodMapping.getOnOpenArgResolvers()) { + if (onOpenArgResolver instanceof PathVariableMethodArgumentResolver || onOpenArgResolver instanceof PathVariableMapMethodArgumentResolver) { + pathMatchers.add(new AntPathMatcherWrapper(path)); + return; + } + } + pathMatchers.add(new DefaultPathMatcher(path)); + } + + private PojoMethodMapping getPojoMethodMapping(String path, Channel channel) { + PojoMethodMapping methodMapping; + if (pathMethodMappingMap.size() == 1) { + methodMapping = pathMethodMappingMap.values().iterator().next(); + } else { + Attribute attrPath = channel.attr(PATH_KEY); + attrPath.set(path); + methodMapping = pathMethodMappingMap.get(path); + if (methodMapping == null) { + throw new RuntimeException("path " + path + " is not in pathMethodMappingMap "); + } + } + return methodMapping; + } +} diff --git a/src/main/java/org/yeauty/pojo/PojoMethodMapping.java b/src/main/java/org/yeauty/pojo/PojoMethodMapping.java index 489b977f7a6cc0cc66bef0ca93d756b634bd5ba8..4a75abeabba516a344ed774ffc47b82f7c085ea8 100755 --- a/src/main/java/org/yeauty/pojo/PojoMethodMapping.java +++ b/src/main/java/org/yeauty/pojo/PojoMethodMapping.java @@ -1,371 +1,376 @@ -package org.yeauty.pojo; - -import io.netty.channel.Channel; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; -import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; -import org.springframework.beans.factory.support.AbstractBeanFactory; -import org.springframework.context.ApplicationContext; -import org.springframework.core.DefaultParameterNameDiscoverer; -import org.springframework.core.MethodParameter; -import org.springframework.core.ParameterNameDiscoverer; -import org.yeauty.annotation.*; -import org.yeauty.exception.DeploymentException; -import org.yeauty.support.*; - -import java.lang.annotation.Annotation; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class PojoMethodMapping { - - private static final ParameterNameDiscoverer parameterNameDiscoverer = new DefaultParameterNameDiscoverer(); - - private final Method beforeHandshake; - private final Method onOpen; - private final Method onClose; - private final Method onError; - private final Method onMessage; - private final Method onBinary; - private final Method onEvent; - private final MethodParameter[] beforeHandshakeParameters; - private final MethodParameter[] onOpenParameters; - private final MethodParameter[] onCloseParameters; - private final MethodParameter[] onErrorParameters; - private final MethodParameter[] onMessageParameters; - private final MethodParameter[] onBinaryParameters; - private final MethodParameter[] onEventParameters; - private final MethodArgumentResolver[] beforeHandshakeArgResolvers; - private final MethodArgumentResolver[] onOpenArgResolvers; - private final MethodArgumentResolver[] onCloseArgResolvers; - private final MethodArgumentResolver[] onErrorArgResolvers; - private final MethodArgumentResolver[] onMessageArgResolvers; - private final MethodArgumentResolver[] onBinaryArgResolvers; - private final MethodArgumentResolver[] onEventArgResolvers; - private final Class pojoClazz; - private final ApplicationContext applicationContext; - private final AbstractBeanFactory beanFactory; - - public PojoMethodMapping(Class pojoClazz, ApplicationContext context, AbstractBeanFactory beanFactory) throws DeploymentException { - this.applicationContext = context; - this.pojoClazz = pojoClazz; - this.beanFactory = beanFactory; - Method handshake = null; - Method open = null; - Method close = null; - Method error = null; - Method message = null; - Method binary = null; - Method event = null; - Method[] pojoClazzMethods = null; - Class currentClazz = pojoClazz; - while (!currentClazz.equals(Object.class)) { - Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); - if (currentClazz == pojoClazz) { - pojoClazzMethods = currentClazzMethods; - } - for (Method method : currentClazzMethods) { - if (method.getAnnotation(BeforeHandshake.class) != null) { - checkPublic(method); - if (handshake == null) { - handshake = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(handshake, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation BeforeHandshake"); - } - } - } else if (method.getAnnotation(OnOpen.class) != null) { - checkPublic(method); - if (open == null) { - open = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(open, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation OnOpen"); - } - } - } else if (method.getAnnotation(OnClose.class) != null) { - checkPublic(method); - if (close == null) { - close = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(close, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation OnClose"); - } - } - } else if (method.getAnnotation(OnError.class) != null) { - checkPublic(method); - if (error == null) { - error = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(error, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation OnError"); - } - } - } else if (method.getAnnotation(OnMessage.class) != null) { - checkPublic(method); - if (message == null) { - message = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(message, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation onMessage"); - } - } - } else if (method.getAnnotation(OnBinary.class) != null) { - checkPublic(method); - if (binary == null) { - binary = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(binary, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation OnBinary"); - } - } - } else if (method.getAnnotation(OnEvent.class) != null) { - checkPublic(method); - if (event == null) { - event = method; - } else { - if (currentClazz == pojoClazz || - !isMethodOverride(event, method)) { - // Duplicate annotation - throw new DeploymentException( - "pojoMethodMapping.duplicateAnnotation OnEvent"); - } - } - } else { - // Method not annotated - } - } - currentClazz = currentClazz.getSuperclass(); - } - // If the methods are not on pojoClazz and they are overridden - // by a non annotated method in pojoClazz, they should be ignored - if (handshake != null && handshake.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, handshake, BeforeHandshake.class)) { - handshake = null; - } - } - if (open != null && open.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, open, OnOpen.class)) { - open = null; - } - } - if (close != null && close.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, close, OnClose.class)) { - close = null; - } - } - if (error != null && error.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, error, OnError.class)) { - error = null; - } - } - if (message != null && message.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, message, OnMessage.class)) { - message = null; - } - } - if (binary != null && binary.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, binary, OnBinary.class)) { - binary = null; - } - } - if (event != null && event.getDeclaringClass() != pojoClazz) { - if (isOverridenWithoutAnnotation(pojoClazzMethods, event, OnEvent.class)) { - event = null; - } - } - - this.beforeHandshake = handshake; - this.onOpen = open; - this.onClose = close; - this.onError = error; - this.onMessage = message; - this.onBinary = binary; - this.onEvent = event; - beforeHandshakeParameters = getParameters(beforeHandshake); - onOpenParameters = getParameters(onOpen); - onCloseParameters = getParameters(onClose); - onMessageParameters = getParameters(onMessage); - onErrorParameters = getParameters(onError); - onBinaryParameters = getParameters(onBinary); - onEventParameters = getParameters(onEvent); - beforeHandshakeArgResolvers = getResolvers(beforeHandshakeParameters); - onOpenArgResolvers = getResolvers(onOpenParameters); - onCloseArgResolvers = getResolvers(onCloseParameters); - onMessageArgResolvers = getResolvers(onMessageParameters); - onErrorArgResolvers = getResolvers(onErrorParameters); - onBinaryArgResolvers = getResolvers(onBinaryParameters); - onEventArgResolvers = getResolvers(onEventParameters); - } - - private void checkPublic(Method m) throws DeploymentException { - if (!Modifier.isPublic(m.getModifiers())) { - throw new DeploymentException( - "pojoMethodMapping.methodNotPublic " + m.getName()); - } - } - - private boolean isMethodOverride(Method method1, Method method2) { - return (method1.getName().equals(method2.getName()) - && method1.getReturnType().equals(method2.getReturnType()) - && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes())); - } - - private boolean isOverridenWithoutAnnotation(Method[] methods, Method superclazzMethod, Class annotation) { - for (Method method : methods) { - if (isMethodOverride(method, superclazzMethod) - && (method.getAnnotation(annotation) == null)) { - return true; - } - } - return false; - } - - Object getEndpointInstance() throws NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException { - Object implement = pojoClazz.getDeclaredConstructor().newInstance(); - AutowiredAnnotationBeanPostProcessor postProcessor = applicationContext.getBean(AutowiredAnnotationBeanPostProcessor.class); - postProcessor.postProcessPropertyValues(null, null, implement, null); - return implement; - } - - Method getBeforeHandshake() { - return beforeHandshake; - } - - Object[] getBeforeHandshakeArgs(Channel channel, FullHttpRequest req) throws Exception { - return getMethodArgumentValues(channel, req, beforeHandshakeParameters, beforeHandshakeArgResolvers); - } - - Method getOnOpen() { - return onOpen; - } - - Object[] getOnOpenArgs(Channel channel, FullHttpRequest req) throws Exception { - return getMethodArgumentValues(channel, req, onOpenParameters, onOpenArgResolvers); - } - - MethodArgumentResolver[] getOnOpenArgResolvers() { - return onOpenArgResolvers; - } - - Method getOnClose() { - return onClose; - } - - Object[] getOnCloseArgs(Channel channel) throws Exception { - return getMethodArgumentValues(channel, null, onCloseParameters, onCloseArgResolvers); - } - - Method getOnError() { - return onError; - } - - Object[] getOnErrorArgs(Channel channel, Throwable throwable) throws Exception { - return getMethodArgumentValues(channel, throwable, onErrorParameters, onErrorArgResolvers); - } - - Method getOnMessage() { - return onMessage; - } - - Object[] getOnMessageArgs(Channel channel, TextWebSocketFrame textWebSocketFrame) throws Exception { - return getMethodArgumentValues(channel, textWebSocketFrame, onMessageParameters, onMessageArgResolvers); - } - - Method getOnBinary() { - return onBinary; - } - - Object[] getOnBinaryArgs(Channel channel, BinaryWebSocketFrame binaryWebSocketFrame) throws Exception { - return getMethodArgumentValues(channel, binaryWebSocketFrame, onBinaryParameters, onBinaryArgResolvers); - } - - Method getOnEvent() { - return onEvent; - } - - Object[] getOnEventArgs(Channel channel, Object evt) throws Exception { - return getMethodArgumentValues(channel, evt, onEventParameters, onEventArgResolvers); - } - - private Object[] getMethodArgumentValues(Channel channel, Object object, MethodParameter[] parameters, MethodArgumentResolver[] resolvers) throws Exception { - Object[] objects = new Object[parameters.length]; - for (int i = 0; i < parameters.length; i++) { - MethodParameter parameter = parameters[i]; - MethodArgumentResolver resolver = resolvers[i]; - Object arg = resolver.resolveArgument(parameter, channel, object); - objects[i] = arg; - } - return objects; - } - - private MethodArgumentResolver[] getResolvers(MethodParameter[] parameters) throws DeploymentException { - MethodArgumentResolver[] methodArgumentResolvers = new MethodArgumentResolver[parameters.length]; - List resolvers = getDefaultResolvers(); - for (int i = 0; i < parameters.length; i++) { - MethodParameter parameter = parameters[i]; - for (MethodArgumentResolver resolver : resolvers) { - if (resolver.supportsParameter(parameter)) { - methodArgumentResolvers[i] = resolver; - break; - } - } - if (methodArgumentResolvers[i] == null) { - throw new DeploymentException("pojoMethodMapping.paramClassIncorrect parameter name : " + parameter.getParameterName()); - } - } - return methodArgumentResolvers; - } - - private List getDefaultResolvers() { - List resolvers = new ArrayList<>(); - resolvers.add(new SessionMethodArgumentResolver()); - resolvers.add(new HttpHeadersMethodArgumentResolver()); - resolvers.add(new TextMethodArgumentResolver()); - resolvers.add(new ThrowableMethodArgumentResolver()); - resolvers.add(new ByteMethodArgumentResolver()); - resolvers.add(new RequestParamMapMethodArgumentResolver()); - resolvers.add(new RequestParamMethodArgumentResolver(beanFactory)); - resolvers.add(new PathVariableMapMethodArgumentResolver()); - resolvers.add(new PathVariableMethodArgumentResolver(beanFactory)); - resolvers.add(new EventMethodArgumentResolver(beanFactory)); - return resolvers; - } - - private static MethodParameter[] getParameters(Method m) { - if (m == null) { - return new MethodParameter[0]; - } - int count = m.getParameterCount(); - MethodParameter[] result = new MethodParameter[count]; - for (int i = 0; i < count; i++) { - MethodParameter methodParameter = new MethodParameter(m, i); - methodParameter.initParameterNameDiscovery(parameterNameDiscoverer); - result[i] = methodParameter; - } - return result; - } +package org.yeauty.pojo; + +import io.netty.channel.Channel; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; +import org.springframework.beans.factory.support.AbstractBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; +import org.yeauty.annotation.*; +import org.yeauty.exception.DeploymentException; +import org.yeauty.support.*; + +import java.lang.annotation.Annotation; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class PojoMethodMapping { + + private static final ParameterNameDiscoverer parameterNameDiscoverer = new DefaultParameterNameDiscoverer(); + + private final Method beforeHandshake; + private final Method onOpen; + private final Method onClose; + private final Method onError; + private final Method onMessage; + private final Method onBinary; + private final Method onEvent; + private final MethodParameter[] beforeHandshakeParameters; + private final MethodParameter[] onOpenParameters; + private final MethodParameter[] onCloseParameters; + private final MethodParameter[] onErrorParameters; + private final MethodParameter[] onMessageParameters; + private final MethodParameter[] onBinaryParameters; + private final MethodParameter[] onEventParameters; + private final MethodArgumentResolver[] beforeHandshakeArgResolvers; + private final MethodArgumentResolver[] onOpenArgResolvers; + private final MethodArgumentResolver[] onCloseArgResolvers; + private final MethodArgumentResolver[] onErrorArgResolvers; + private final MethodArgumentResolver[] onMessageArgResolvers; + private final MethodArgumentResolver[] onBinaryArgResolvers; + private final MethodArgumentResolver[] onEventArgResolvers; + private final Class pojoClazz; + private final ApplicationContext applicationContext; + private final AbstractBeanFactory beanFactory; + + public PojoMethodMapping(Class pojoClazz, ApplicationContext context, AbstractBeanFactory beanFactory) throws DeploymentException { + this.applicationContext = context; + this.pojoClazz = pojoClazz; + this.beanFactory = beanFactory; + Method handshake = null; + Method open = null; + Method close = null; + Method error = null; + Method message = null; + Method binary = null; + Method event = null; + Method[] pojoClazzMethods = null; + Class currentClazz = pojoClazz; + while (!currentClazz.equals(Object.class)) { + Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); + if (currentClazz == pojoClazz) { + pojoClazzMethods = currentClazzMethods; + } + for (Method method : currentClazzMethods) { + if (method.getAnnotation(BeforeHandshake.class) != null) { + checkPublic(method); + if (handshake == null) { + handshake = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(handshake, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation BeforeHandshake"); + } + } + } else if (method.getAnnotation(OnOpen.class) != null) { + checkPublic(method); + if (open == null) { + open = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(open, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation OnOpen"); + } + } + } else if (method.getAnnotation(OnClose.class) != null) { + checkPublic(method); + if (close == null) { + close = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(close, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation OnClose"); + } + } + } else if (method.getAnnotation(OnError.class) != null) { + checkPublic(method); + if (error == null) { + error = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(error, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation OnError"); + } + } + } else if (method.getAnnotation(OnMessage.class) != null) { + checkPublic(method); + if (message == null) { + message = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(message, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation onMessage"); + } + } + } else if (method.getAnnotation(OnBinary.class) != null) { + checkPublic(method); + if (binary == null) { + binary = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(binary, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation OnBinary"); + } + } + } else if (method.getAnnotation(OnEvent.class) != null) { + checkPublic(method); + if (event == null) { + event = method; + } else { + if (currentClazz == pojoClazz || + !isMethodOverride(event, method)) { + // Duplicate annotation + throw new DeploymentException( + "pojoMethodMapping.duplicateAnnotation OnEvent"); + } + } + } else { + // Method not annotated + } + } + currentClazz = currentClazz.getSuperclass(); + } + // If the methods are not on pojoClazz and they are overridden + // by a non annotated method in pojoClazz, they should be ignored + if (handshake != null && handshake.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, handshake, BeforeHandshake.class)) { + handshake = null; + } + } + if (open != null && open.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, open, OnOpen.class)) { + open = null; + } + } + if (close != null && close.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, close, OnClose.class)) { + close = null; + } + } + if (error != null && error.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, error, OnError.class)) { + error = null; + } + } + if (message != null && message.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, message, OnMessage.class)) { + message = null; + } + } + if (binary != null && binary.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, binary, OnBinary.class)) { + binary = null; + } + } + if (event != null && event.getDeclaringClass() != pojoClazz) { + if (isOverridenWithoutAnnotation(pojoClazzMethods, event, OnEvent.class)) { + event = null; + } + } + + this.beforeHandshake = handshake; + this.onOpen = open; + this.onClose = close; + this.onError = error; + this.onMessage = message; + this.onBinary = binary; + this.onEvent = event; + beforeHandshakeParameters = getParameters(beforeHandshake); + onOpenParameters = getParameters(onOpen); + onCloseParameters = getParameters(onClose); + onMessageParameters = getParameters(onMessage); + onErrorParameters = getParameters(onError); + onBinaryParameters = getParameters(onBinary); + onEventParameters = getParameters(onEvent); + beforeHandshakeArgResolvers = getResolvers(beforeHandshakeParameters); + onOpenArgResolvers = getResolvers(onOpenParameters); + onCloseArgResolvers = getResolvers(onCloseParameters); + onMessageArgResolvers = getResolvers(onMessageParameters); + onErrorArgResolvers = getResolvers(onErrorParameters); + onBinaryArgResolvers = getResolvers(onBinaryParameters); + onEventArgResolvers = getResolvers(onEventParameters); + } + + private void checkPublic(Method m) throws DeploymentException { + if (!Modifier.isPublic(m.getModifiers())) { + throw new DeploymentException( + "pojoMethodMapping.methodNotPublic " + m.getName()); + } + } + + private boolean isMethodOverride(Method method1, Method method2) { + return (method1.getName().equals(method2.getName()) + && method1.getReturnType().equals(method2.getReturnType()) + && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes())); + } + + private boolean isOverridenWithoutAnnotation(Method[] methods, Method superclazzMethod, Class annotation) { + for (Method method : methods) { + if (isMethodOverride(method, superclazzMethod) + && (method.getAnnotation(annotation) == null)) { + return true; + } + } + return false; + } + + Object getEndpointInstance() throws NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException { + Object implement = pojoClazz.getDeclaredConstructor().newInstance(); + AutowiredAnnotationBeanPostProcessor postProcessor = applicationContext.getBean(AutowiredAnnotationBeanPostProcessor.class); + postProcessor.postProcessPropertyValues(null, null, implement, null); + return implement; + } + + Method getBeforeHandshake() { + return beforeHandshake; + } + + Object[] getBeforeHandshakeArgs(Channel channel, FullHttpRequest req) throws Exception { + return getMethodArgumentValues(channel, req, beforeHandshakeParameters, beforeHandshakeArgResolvers); + } + + Method getOnOpen() { + return onOpen; + } + + Object[] getOnOpenArgs(Channel channel, FullHttpRequest req) throws Exception { + return getMethodArgumentValues(channel, req, onOpenParameters, onOpenArgResolvers); + } + + MethodArgumentResolver[] getOnOpenArgResolvers() { + return onOpenArgResolvers; + } + + Method getOnClose() { + return onClose; + } + + Object[] getOnCloseArgs(Channel channel) throws Exception { + return getMethodArgumentValues(channel, null, onCloseParameters, onCloseArgResolvers); + } + + Method getOnError() { + return onError; + } + + Object[] getOnErrorArgs(Channel channel, Throwable throwable) throws Exception { + return getMethodArgumentValues(channel, throwable, onErrorParameters, onErrorArgResolvers); + } + + Method getOnMessage() { + return onMessage; + } + + Object[] getOnMessageArgs(Channel channel, TextWebSocketFrame textWebSocketFrame) throws Exception { + return getMethodArgumentValues(channel, textWebSocketFrame, onMessageParameters, onMessageArgResolvers); + } + + Method getOnBinary() { + return onBinary; + } + + Object[] getOnBinaryArgs(Channel channel, BinaryWebSocketFrame binaryWebSocketFrame) throws Exception { + return getMethodArgumentValues(channel, binaryWebSocketFrame, onBinaryParameters, onBinaryArgResolvers); + } + + Method getOnEvent() { + return onEvent; + } + + Object[] getOnEventArgs(Channel channel, Object evt) throws Exception { + return getMethodArgumentValues(channel, evt, onEventParameters, onEventArgResolvers); + } + + private Object[] getMethodArgumentValues(Channel channel, Object object, MethodParameter[] parameters, MethodArgumentResolver[] resolvers) throws Exception { + Object[] objects = new Object[parameters.length]; + for (int i = 0; i < parameters.length; i++) { + MethodParameter parameter = parameters[i]; + MethodArgumentResolver resolver = resolvers[i]; + Object arg = resolver.resolveArgument(parameter, channel, object); + objects[i] = arg; + } + return objects; + } + + private MethodArgumentResolver[] getResolvers(MethodParameter[] parameters) throws DeploymentException { + MethodArgumentResolver[] methodArgumentResolvers = new MethodArgumentResolver[parameters.length]; + List resolvers = getDefaultResolvers(); + for (int i = 0; i < parameters.length; i++) { + MethodParameter parameter = parameters[i]; + for (MethodArgumentResolver resolver : resolvers) { + if (resolver.supportsParameter(parameter)) { + methodArgumentResolvers[i] = resolver; + break; + } + } + if (methodArgumentResolvers[i] == null) { + throw new DeploymentException("pojoMethodMapping.paramClassIncorrect parameter name : " + parameter.getParameterName()); + } + } + return methodArgumentResolvers; + } + + protected List getDefaultResolvers() {//可继承,重写 + List resolvers = new ArrayList<>(); + resolvers.add(new PathVariableMapMethodArgumentResolver());//优先判断注解 + resolvers.add(new PathVariableMethodArgumentResolver(beanFactory)); + resolvers.add(new SessionMethodArgumentResolver()); + resolvers.add(new HttpHeadersMethodArgumentResolver()); + resolvers.add(new TextMethodArgumentResolver()); + resolvers.add(new ThrowableMethodArgumentResolver()); + resolvers.add(new ByteMethodArgumentResolver()); + resolvers.add(new RequestParamMapMethodArgumentResolver()); + resolvers.add(new RequestParamMethodArgumentResolver(beanFactory)); +// resolvers.add(new PathVariableMapMethodArgumentResolver()); +// resolvers.add(new PathVariableMethodArgumentResolver(beanFactory)); + resolvers.add(new EventMethodArgumentResolver(beanFactory)); + resolvers.add(new ByteBufMethodArgumentResolver());//支持对直接内存操作 + resolvers.add(new ChannelMethodArgumentResolver());//支持Channel + + return resolvers; + } + + private static MethodParameter[] getParameters(Method m) { + if (m == null) { + return new MethodParameter[0]; + } + int count = m.getParameterCount(); + MethodParameter[] result = new MethodParameter[count]; + for (int i = 0; i < count; i++) { + MethodParameter methodParameter = new MethodParameter(m, i); + methodParameter.initParameterNameDiscovery(parameterNameDiscoverer); + result[i] = methodParameter; + } + return result; + } } \ No newline at end of file diff --git a/src/main/java/org/yeauty/standard/ServerEndpointExporter.java b/src/main/java/org/yeauty/standard/ServerEndpointExporter.java index ef85bb94edd9f8d4f9daa67cc3083de648eee435..d99f892730bf7007dda5f95e7a5c4182f43f4263 100755 --- a/src/main/java/org/yeauty/standard/ServerEndpointExporter.java +++ b/src/main/java/org/yeauty/standard/ServerEndpointExporter.java @@ -1,253 +1,267 @@ -package org.yeauty.standard; - -import org.springframework.beans.TypeConverter; -import org.springframework.beans.TypeMismatchException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; -import org.springframework.beans.factory.SmartInitializingSingleton; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.config.BeanExpressionContext; -import org.springframework.beans.factory.config.BeanExpressionResolver; -import org.springframework.beans.factory.support.AbstractBeanFactory; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ResourceLoaderAware; -import org.springframework.context.support.ApplicationObjectSupport; -import org.springframework.core.annotation.AnnotatedElementUtils; -import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.core.env.Environment; -import org.springframework.core.io.ResourceLoader; -import org.springframework.util.ClassUtils; -import org.yeauty.annotation.EnableWebSocket; -import org.yeauty.annotation.ServerEndpoint; -import org.yeauty.exception.DeploymentException; -import org.yeauty.pojo.PojoEndpointServer; -import org.yeauty.pojo.PojoMethodMapping; - -import javax.net.ssl.SSLException; -import java.net.InetSocketAddress; -import java.util.*; - -/** - * @author Yeauty - */ -public class ServerEndpointExporter extends ApplicationObjectSupport implements SmartInitializingSingleton, BeanFactoryAware, ResourceLoaderAware { - - @Autowired - Environment environment; - - private AbstractBeanFactory beanFactory; - - private ResourceLoader resourceLoader; - - private final Map addressWebsocketServerMap = new HashMap<>(); - - @Override - public void afterSingletonsInstantiated() { - registerEndpoints(); - } - - @Override - public void setBeanFactory(BeanFactory beanFactory) { - if (!(beanFactory instanceof AbstractBeanFactory)) { - throw new IllegalArgumentException( - "AutowiredAnnotationBeanPostProcessor requires a AbstractBeanFactory: " + beanFactory); - } - this.beanFactory = (AbstractBeanFactory) beanFactory; - } - - protected void registerEndpoints() { - ApplicationContext context = getApplicationContext(); - - scanPackage(context); - - String[] endpointBeanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class); - Set> endpointClasses = new LinkedHashSet<>(); - for (String beanName : endpointBeanNames) { - endpointClasses.add(context.getType(beanName)); - } - - for (Class endpointClass : endpointClasses) { - if (ClassUtils.isCglibProxyClass(endpointClass)) { - registerEndpoint(endpointClass.getSuperclass()); - } else { - registerEndpoint(endpointClass); - } - } - - init(); - } - - private void scanPackage(ApplicationContext context) { - String[] basePackages = null; - - String[] enableWebSocketBeanNames = context.getBeanNamesForAnnotation(EnableWebSocket.class); - if (enableWebSocketBeanNames.length != 0) { - for (String enableWebSocketBeanName : enableWebSocketBeanNames) { - Object enableWebSocketBean = context.getBean(enableWebSocketBeanName); - EnableWebSocket enableWebSocket = AnnotationUtils.findAnnotation(enableWebSocketBean.getClass(), EnableWebSocket.class); - assert enableWebSocket != null; - if (enableWebSocket.scanBasePackages().length != 0) { - basePackages = enableWebSocket.scanBasePackages(); - break; - } - } - } - - // use @SpringBootApplication package - if (basePackages == null) { - String[] springBootApplicationBeanName = context.getBeanNamesForAnnotation(SpringBootApplication.class); - Object springBootApplicationBean = context.getBean(springBootApplicationBeanName[0]); - SpringBootApplication springBootApplication = AnnotationUtils.findAnnotation(springBootApplicationBean.getClass(), SpringBootApplication.class); - assert springBootApplication != null; - if (springBootApplication.scanBasePackages().length != 0) { - basePackages = springBootApplication.scanBasePackages(); - } else { - String packageName = ClassUtils.getPackageName(springBootApplicationBean.getClass().getName()); - basePackages = new String[1]; - basePackages[0] = packageName; - } - } - - EndpointClassPathScanner scanHandle = new EndpointClassPathScanner((BeanDefinitionRegistry) context, false); - if (resourceLoader != null) { - scanHandle.setResourceLoader(resourceLoader); - } - - for (String basePackage : basePackages) { - scanHandle.doScan(basePackage); - } - } - - private void init() { - for (Map.Entry entry : addressWebsocketServerMap.entrySet()) { - WebsocketServer websocketServer = entry.getValue(); - try { - websocketServer.init(); - PojoEndpointServer pojoEndpointServer = websocketServer.getPojoEndpointServer(); - StringJoiner stringJoiner = new StringJoiner(","); - pojoEndpointServer.getPathMatcherSet().forEach(pathMatcher -> stringJoiner.add("'" + pathMatcher.getPattern() + "'")); - logger.info(String.format("\033[34mNetty WebSocket started on port: %s with context path(s): %s .\033[0m", pojoEndpointServer.getPort(), stringJoiner.toString())); - } catch (InterruptedException e) { - logger.error(String.format("websocket [%s] init fail", entry.getKey()), e); - } catch (SSLException e) { - logger.error(String.format("websocket [%s] ssl create fail", entry.getKey()), e); - - } - } - } - - private void registerEndpoint(Class endpointClass) { - ServerEndpoint annotation = AnnotatedElementUtils.findMergedAnnotation(endpointClass, ServerEndpoint.class); - if (annotation == null) { - throw new IllegalStateException("missingAnnotation ServerEndpoint"); - } - ServerEndpointConfig serverEndpointConfig = buildConfig(annotation); - - ApplicationContext context = getApplicationContext(); - PojoMethodMapping pojoMethodMapping = null; - try { - pojoMethodMapping = new PojoMethodMapping(endpointClass, context, beanFactory); - } catch (DeploymentException e) { - throw new IllegalStateException("Failed to register ServerEndpointConfig: " + serverEndpointConfig, e); - } - - InetSocketAddress inetSocketAddress = new InetSocketAddress(serverEndpointConfig.getHost(), serverEndpointConfig.getPort()); - String path = resolveAnnotationValue(annotation.value(), String.class, "path"); - - WebsocketServer websocketServer = addressWebsocketServerMap.get(inetSocketAddress); - if (websocketServer == null) { - PojoEndpointServer pojoEndpointServer = new PojoEndpointServer(pojoMethodMapping, serverEndpointConfig, path); - websocketServer = new WebsocketServer(pojoEndpointServer, serverEndpointConfig); - addressWebsocketServerMap.put(inetSocketAddress, websocketServer); - } else { - websocketServer.getPojoEndpointServer().addPathPojoMethodMapping(path, pojoMethodMapping); - } - } - - private ServerEndpointConfig buildConfig(ServerEndpoint annotation) { - String host = resolveAnnotationValue(annotation.host(), String.class, "host"); - int port = resolveAnnotationValue(annotation.port(), Integer.class, "port"); - String path = resolveAnnotationValue(annotation.value(), String.class, "value"); - int bossLoopGroupThreads = resolveAnnotationValue(annotation.bossLoopGroupThreads(), Integer.class, "bossLoopGroupThreads"); - int workerLoopGroupThreads = resolveAnnotationValue(annotation.workerLoopGroupThreads(), Integer.class, "workerLoopGroupThreads"); - boolean useCompressionHandler = resolveAnnotationValue(annotation.useCompressionHandler(), Boolean.class, "useCompressionHandler"); - - int optionConnectTimeoutMillis = resolveAnnotationValue(annotation.optionConnectTimeoutMillis(), Integer.class, "optionConnectTimeoutMillis"); - int optionSoBacklog = resolveAnnotationValue(annotation.optionSoBacklog(), Integer.class, "optionSoBacklog"); - - int childOptionWriteSpinCount = resolveAnnotationValue(annotation.childOptionWriteSpinCount(), Integer.class, "childOptionWriteSpinCount"); - int childOptionWriteBufferHighWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferHighWaterMark(), Integer.class, "childOptionWriteBufferHighWaterMark"); - int childOptionWriteBufferLowWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferLowWaterMark(), Integer.class, "childOptionWriteBufferLowWaterMark"); - int childOptionSoRcvbuf = resolveAnnotationValue(annotation.childOptionSoRcvbuf(), Integer.class, "childOptionSoRcvbuf"); - int childOptionSoSndbuf = resolveAnnotationValue(annotation.childOptionSoSndbuf(), Integer.class, "childOptionSoSndbuf"); - boolean childOptionTcpNodelay = resolveAnnotationValue(annotation.childOptionTcpNodelay(), Boolean.class, "childOptionTcpNodelay"); - boolean childOptionSoKeepalive = resolveAnnotationValue(annotation.childOptionSoKeepalive(), Boolean.class, "childOptionSoKeepalive"); - int childOptionSoLinger = resolveAnnotationValue(annotation.childOptionSoLinger(), Integer.class, "childOptionSoLinger"); - boolean childOptionAllowHalfClosure = resolveAnnotationValue(annotation.childOptionAllowHalfClosure(), Boolean.class, "childOptionAllowHalfClosure"); - - int readerIdleTimeSeconds = resolveAnnotationValue(annotation.readerIdleTimeSeconds(), Integer.class, "readerIdleTimeSeconds"); - int writerIdleTimeSeconds = resolveAnnotationValue(annotation.writerIdleTimeSeconds(), Integer.class, "writerIdleTimeSeconds"); - int allIdleTimeSeconds = resolveAnnotationValue(annotation.allIdleTimeSeconds(), Integer.class, "allIdleTimeSeconds"); - - int maxFramePayloadLength = resolveAnnotationValue(annotation.maxFramePayloadLength(), Integer.class, "maxFramePayloadLength"); - - boolean useEventExecutorGroup = resolveAnnotationValue(annotation.useEventExecutorGroup(), Boolean.class, "useEventExecutorGroup"); - int eventExecutorGroupThreads = resolveAnnotationValue(annotation.eventExecutorGroupThreads(), Integer.class, "eventExecutorGroupThreads"); - - String sslKeyPassword = resolveAnnotationValue(annotation.sslKeyPassword(), String.class, "sslKeyPassword"); - String sslKeyStore = resolveAnnotationValue(annotation.sslKeyStore(), String.class, "sslKeyStore"); - String sslKeyStorePassword = resolveAnnotationValue(annotation.sslKeyStorePassword(), String.class, "sslKeyStorePassword"); - String sslKeyStoreType = resolveAnnotationValue(annotation.sslKeyStoreType(), String.class, "sslKeyStoreType"); - String sslTrustStore = resolveAnnotationValue(annotation.sslTrustStore(), String.class, "sslTrustStore"); - String sslTrustStorePassword = resolveAnnotationValue(annotation.sslTrustStorePassword(), String.class, "sslTrustStorePassword"); - String sslTrustStoreType = resolveAnnotationValue(annotation.sslTrustStoreType(), String.class, "sslTrustStoreType"); - - String[] corsOrigins = annotation.corsOrigins(); - if (corsOrigins.length != 0) { - for (int i = 0; i < corsOrigins.length; i++) { - corsOrigins[i] = resolveAnnotationValue(corsOrigins[i], String.class, "corsOrigins"); - } - } - Boolean corsAllowCredentials = resolveAnnotationValue(annotation.corsAllowCredentials(), Boolean.class, "corsAllowCredentials"); - - ServerEndpointConfig serverEndpointConfig = new ServerEndpointConfig(host, port, bossLoopGroupThreads, workerLoopGroupThreads - , useCompressionHandler, optionConnectTimeoutMillis, optionSoBacklog, childOptionWriteSpinCount, childOptionWriteBufferHighWaterMark - , childOptionWriteBufferLowWaterMark, childOptionSoRcvbuf, childOptionSoSndbuf, childOptionTcpNodelay, childOptionSoKeepalive - , childOptionSoLinger, childOptionAllowHalfClosure, readerIdleTimeSeconds, writerIdleTimeSeconds, allIdleTimeSeconds - , maxFramePayloadLength, useEventExecutorGroup, eventExecutorGroupThreads - , sslKeyPassword, sslKeyStore, sslKeyStorePassword, sslKeyStoreType - , sslTrustStore, sslTrustStorePassword, sslTrustStoreType - , corsOrigins, corsAllowCredentials); - - return serverEndpointConfig; - } - - private T resolveAnnotationValue(Object value, Class requiredType, String paramName) { - if (value == null) { - return null; - } - TypeConverter typeConverter = beanFactory.getTypeConverter(); - - if (value instanceof String) { - String strVal = beanFactory.resolveEmbeddedValue((String) value); - BeanExpressionResolver beanExpressionResolver = beanFactory.getBeanExpressionResolver(); - if (beanExpressionResolver != null) { - value = beanExpressionResolver.evaluate(strVal, new BeanExpressionContext(beanFactory, null)); - } else { - value = strVal; - } - } - try { - return typeConverter.convertIfNecessary(value, requiredType); - } catch (TypeMismatchException e) { - throw new IllegalArgumentException("Failed to convert value of parameter '" + paramName + "' to required type '" + requiredType.getName() + "'"); - } - } - - @Override - public void setResourceLoader(ResourceLoader resourceLoader) { - this.resourceLoader = resourceLoader; - } -} +package org.yeauty.standard; + +import org.springframework.beans.TypeConverter; +import org.springframework.beans.TypeMismatchException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.BeanExpressionContext; +import org.springframework.beans.factory.config.BeanExpressionResolver; +import org.springframework.beans.factory.support.AbstractBeanFactory; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ResourceLoaderAware; +import org.springframework.context.support.ApplicationObjectSupport; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.env.Environment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.util.ClassUtils; +import org.yeauty.annotation.EnableWebSocket; +import org.yeauty.annotation.ServerEndpoint; +import org.yeauty.exception.DeploymentException; +import org.yeauty.pojo.PojoEndpointServer; +import org.yeauty.pojo.PojoMethodMapping; + +import javax.net.ssl.SSLException; +import java.net.InetSocketAddress; +import java.util.*; + +/** + * @author Yeauty + */ +public class ServerEndpointExporter extends ApplicationObjectSupport implements SmartInitializingSingleton, BeanFactoryAware, ResourceLoaderAware { + + @Autowired + Environment environment; + + protected AbstractBeanFactory beanFactory; + + private ResourceLoader resourceLoader; + + private final Map addressWebsocketServerMap = new HashMap<>(); + + @Override + public void afterSingletonsInstantiated() { + registerEndpoints(); + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) { + if (!(beanFactory instanceof AbstractBeanFactory)) { + throw new IllegalArgumentException( + "AutowiredAnnotationBeanPostProcessor requires a AbstractBeanFactory: " + beanFactory); + } + this.beanFactory = (AbstractBeanFactory) beanFactory; + } + + protected void registerEndpoints() { + ApplicationContext context = getApplicationContext(); + + scanPackage(context); + + String[] endpointBeanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class); + Set> endpointClasses = new LinkedHashSet<>(); + for (String beanName : endpointBeanNames) { + endpointClasses.add(context.getType(beanName)); + } + + for (Class endpointClass : endpointClasses) { + if (ClassUtils.isCglibProxyClass(endpointClass)) { + registerEndpoint(endpointClass.getSuperclass()); + } else { + registerEndpoint(endpointClass); + } + } + + init(); + } + + private void scanPackage(ApplicationContext context) { + String[] basePackages = null; + + String[] enableWebSocketBeanNames = context.getBeanNamesForAnnotation(EnableWebSocket.class); + if (enableWebSocketBeanNames.length != 0) { + for (String enableWebSocketBeanName : enableWebSocketBeanNames) { + Object enableWebSocketBean = context.getBean(enableWebSocketBeanName); + EnableWebSocket enableWebSocket = AnnotationUtils.findAnnotation(enableWebSocketBean.getClass(), EnableWebSocket.class); + assert enableWebSocket != null; + if (enableWebSocket.scanBasePackages().length != 0) { + basePackages = enableWebSocket.scanBasePackages(); + break; + } + } + } + + // use @SpringBootApplication package + if (basePackages == null) { + String[] springBootApplicationBeanName = context.getBeanNamesForAnnotation(SpringBootApplication.class); + Object springBootApplicationBean = context.getBean(springBootApplicationBeanName[0]); + SpringBootApplication springBootApplication = AnnotationUtils.findAnnotation(springBootApplicationBean.getClass(), SpringBootApplication.class); + assert springBootApplication != null; + if (springBootApplication.scanBasePackages().length != 0) { + basePackages = springBootApplication.scanBasePackages(); + } else { + String packageName = ClassUtils.getPackageName(springBootApplicationBean.getClass().getName()); + basePackages = new String[1]; + basePackages[0] = packageName; + } + } + + EndpointClassPathScanner scanHandle = new EndpointClassPathScanner((BeanDefinitionRegistry) context, false); + if (resourceLoader != null) { + scanHandle.setResourceLoader(resourceLoader); + } + + for (String basePackage : basePackages) { + scanHandle.doScan(basePackage); + } + } + + private void init() { + for (Map.Entry entry : addressWebsocketServerMap.entrySet()) { + WebsocketServer websocketServer = entry.getValue(); + try { + websocketServer.init(); + PojoEndpointServer pojoEndpointServer = websocketServer.getPojoEndpointServer(); + StringJoiner stringJoiner = new StringJoiner(","); + pojoEndpointServer.getPathMatcherSet().forEach(pathMatcher -> stringJoiner.add("'" + pathMatcher.getPattern() + "'")); + logger.info(String.format("\033[34mNetty WebSocket started on port: %s with context path(s): %s .\033[0m", pojoEndpointServer.getPort(), stringJoiner.toString())); + } catch (InterruptedException e) { + logger.error(String.format("websocket [%s] init fail", entry.getKey()), e); + } catch (SSLException e) { + logger.error(String.format("websocket [%s] ssl create fail", entry.getKey()), e); + + } + } + } + + private void registerEndpoint(Class endpointClass) { + ServerEndpoint annotation = AnnotatedElementUtils.findMergedAnnotation(endpointClass, ServerEndpoint.class); + if (annotation == null) { + throw new IllegalStateException("missingAnnotation ServerEndpoint"); + } + ServerEndpointConfig serverEndpointConfig = buildConfig(annotation); + + ApplicationContext context = getApplicationContext(); + PojoMethodMapping pojoMethodMapping = null; + try { + pojoMethodMapping = newPojoMethodMapping(endpointClass, context); + } catch (DeploymentException e) { + throw new IllegalStateException("Failed to register ServerEndpointConfig: " + serverEndpointConfig, e); + } + + InetSocketAddress inetSocketAddress = new InetSocketAddress(serverEndpointConfig.getHost(), serverEndpointConfig.getPort()); + String path = resolveAnnotationValue(annotation.value(), String.class, "path"); + + WebsocketServer websocketServer = addressWebsocketServerMap.get(inetSocketAddress); + if (websocketServer == null) { + PojoEndpointServer pojoEndpointServer = new PojoEndpointServer(pojoMethodMapping, serverEndpointConfig, path); + websocketServer = new WebsocketServer(pojoEndpointServer, serverEndpointConfig); + addressWebsocketServerMap.put(inetSocketAddress, websocketServer); + } else { + websocketServer.getPojoEndpointServer().addPathPojoMethodMapping(path, pojoMethodMapping); + } + } + + /** + * 可集成重写,支持自定义PojoMethodMapping + * @param endpointClass + * @param context + * @return + * @throws DeploymentException + */ + protected PojoMethodMapping newPojoMethodMapping(Class endpointClass, ApplicationContext context) + throws DeploymentException { + PojoMethodMapping pojoMethodMapping; + pojoMethodMapping = new PojoMethodMapping(endpointClass, context, beanFactory); + return pojoMethodMapping; + } + + private ServerEndpointConfig buildConfig(ServerEndpoint annotation) { + String host = resolveAnnotationValue(annotation.host(), String.class, "host"); + int port = resolveAnnotationValue(annotation.port(), Integer.class, "port"); + String path = resolveAnnotationValue(annotation.value(), String.class, "value"); + int bossLoopGroupThreads = resolveAnnotationValue(annotation.bossLoopGroupThreads(), Integer.class, "bossLoopGroupThreads"); + int workerLoopGroupThreads = resolveAnnotationValue(annotation.workerLoopGroupThreads(), Integer.class, "workerLoopGroupThreads"); + boolean useCompressionHandler = resolveAnnotationValue(annotation.useCompressionHandler(), Boolean.class, "useCompressionHandler"); + + int optionConnectTimeoutMillis = resolveAnnotationValue(annotation.optionConnectTimeoutMillis(), Integer.class, "optionConnectTimeoutMillis"); + int optionSoBacklog = resolveAnnotationValue(annotation.optionSoBacklog(), Integer.class, "optionSoBacklog"); + + int childOptionWriteSpinCount = resolveAnnotationValue(annotation.childOptionWriteSpinCount(), Integer.class, "childOptionWriteSpinCount"); + int childOptionWriteBufferHighWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferHighWaterMark(), Integer.class, "childOptionWriteBufferHighWaterMark"); + int childOptionWriteBufferLowWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferLowWaterMark(), Integer.class, "childOptionWriteBufferLowWaterMark"); + int childOptionSoRcvbuf = resolveAnnotationValue(annotation.childOptionSoRcvbuf(), Integer.class, "childOptionSoRcvbuf"); + int childOptionSoSndbuf = resolveAnnotationValue(annotation.childOptionSoSndbuf(), Integer.class, "childOptionSoSndbuf"); + boolean childOptionTcpNodelay = resolveAnnotationValue(annotation.childOptionTcpNodelay(), Boolean.class, "childOptionTcpNodelay"); + boolean childOptionSoKeepalive = resolveAnnotationValue(annotation.childOptionSoKeepalive(), Boolean.class, "childOptionSoKeepalive"); + int childOptionSoLinger = resolveAnnotationValue(annotation.childOptionSoLinger(), Integer.class, "childOptionSoLinger"); + boolean childOptionAllowHalfClosure = resolveAnnotationValue(annotation.childOptionAllowHalfClosure(), Boolean.class, "childOptionAllowHalfClosure"); + + int readerIdleTimeSeconds = resolveAnnotationValue(annotation.readerIdleTimeSeconds(), Integer.class, "readerIdleTimeSeconds"); + int writerIdleTimeSeconds = resolveAnnotationValue(annotation.writerIdleTimeSeconds(), Integer.class, "writerIdleTimeSeconds"); + int allIdleTimeSeconds = resolveAnnotationValue(annotation.allIdleTimeSeconds(), Integer.class, "allIdleTimeSeconds"); + + int maxFramePayloadLength = resolveAnnotationValue(annotation.maxFramePayloadLength(), Integer.class, "maxFramePayloadLength"); + + boolean useEventExecutorGroup = resolveAnnotationValue(annotation.useEventExecutorGroup(), Boolean.class, "useEventExecutorGroup"); + int eventExecutorGroupThreads = resolveAnnotationValue(annotation.eventExecutorGroupThreads(), Integer.class, "eventExecutorGroupThreads"); + + String sslKeyPassword = resolveAnnotationValue(annotation.sslKeyPassword(), String.class, "sslKeyPassword"); + String sslKeyStore = resolveAnnotationValue(annotation.sslKeyStore(), String.class, "sslKeyStore"); + String sslKeyStorePassword = resolveAnnotationValue(annotation.sslKeyStorePassword(), String.class, "sslKeyStorePassword"); + String sslKeyStoreType = resolveAnnotationValue(annotation.sslKeyStoreType(), String.class, "sslKeyStoreType"); + String sslTrustStore = resolveAnnotationValue(annotation.sslTrustStore(), String.class, "sslTrustStore"); + String sslTrustStorePassword = resolveAnnotationValue(annotation.sslTrustStorePassword(), String.class, "sslTrustStorePassword"); + String sslTrustStoreType = resolveAnnotationValue(annotation.sslTrustStoreType(), String.class, "sslTrustStoreType"); + + String[] corsOrigins = annotation.corsOrigins(); + if (corsOrigins.length != 0) { + for (int i = 0; i < corsOrigins.length; i++) { + corsOrigins[i] = resolveAnnotationValue(corsOrigins[i], String.class, "corsOrigins"); + } + } + Boolean corsAllowCredentials = resolveAnnotationValue(annotation.corsAllowCredentials(), Boolean.class, "corsAllowCredentials"); + + ServerEndpointConfig serverEndpointConfig = new ServerEndpointConfig(host, port, bossLoopGroupThreads, workerLoopGroupThreads + , useCompressionHandler, optionConnectTimeoutMillis, optionSoBacklog, childOptionWriteSpinCount, childOptionWriteBufferHighWaterMark + , childOptionWriteBufferLowWaterMark, childOptionSoRcvbuf, childOptionSoSndbuf, childOptionTcpNodelay, childOptionSoKeepalive + , childOptionSoLinger, childOptionAllowHalfClosure, readerIdleTimeSeconds, writerIdleTimeSeconds, allIdleTimeSeconds + , maxFramePayloadLength, useEventExecutorGroup, eventExecutorGroupThreads + , sslKeyPassword, sslKeyStore, sslKeyStorePassword, sslKeyStoreType + , sslTrustStore, sslTrustStorePassword, sslTrustStoreType + , corsOrigins, corsAllowCredentials); + + return serverEndpointConfig; + } + + private T resolveAnnotationValue(Object value, Class requiredType, String paramName) { + if (value == null) { + return null; + } + TypeConverter typeConverter = beanFactory.getTypeConverter(); + + if (value instanceof String) { + String strVal = beanFactory.resolveEmbeddedValue((String) value); + BeanExpressionResolver beanExpressionResolver = beanFactory.getBeanExpressionResolver(); + if (beanExpressionResolver != null) { + value = beanExpressionResolver.evaluate(strVal, new BeanExpressionContext(beanFactory, null)); + } else { + value = strVal; + } + } + try { + return typeConverter.convertIfNecessary(value, requiredType); + } catch (TypeMismatchException e) { + throw new IllegalArgumentException("Failed to convert value of parameter '" + paramName + "' to required type '" + requiredType.getName() + "'"); + } + } + + @Override + public void setResourceLoader(ResourceLoader resourceLoader) { + this.resourceLoader = resourceLoader; + } +} diff --git a/src/main/java/org/yeauty/standard/WebSocketServerHandler.java b/src/main/java/org/yeauty/standard/WebSocketServerHandler.java index a51a41f258e1b5a157a68409e04692efbabecd7a..d23352335e393f1b08d5c7be7bf029b98d25a6e3 100755 --- a/src/main/java/org/yeauty/standard/WebSocketServerHandler.java +++ b/src/main/java/org/yeauty/standard/WebSocketServerHandler.java @@ -1,59 +1,100 @@ -package org.yeauty.standard; - -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.websocketx.*; -import org.yeauty.pojo.PojoEndpointServer; - -class WebSocketServerHandler extends SimpleChannelInboundHandler { - - private final PojoEndpointServer pojoEndpointServer; - - public WebSocketServerHandler(PojoEndpointServer pojoEndpointServer) { - this.pojoEndpointServer = pojoEndpointServer; - } - - @Override - protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception { - handleWebSocketFrame(ctx, msg); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - pojoEndpointServer.doOnError(ctx.channel(), cause); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - pojoEndpointServer.doOnClose(ctx.channel()); - } - - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - pojoEndpointServer.doOnEvent(ctx.channel(), evt); - } - - private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { - if (frame instanceof TextWebSocketFrame) { - pojoEndpointServer.doOnMessage(ctx.channel(), frame); - return; - } - if (frame instanceof PingWebSocketFrame) { - ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain())); - return; - } - if (frame instanceof CloseWebSocketFrame) { - ctx.writeAndFlush(frame.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE); - return; - } - if (frame instanceof BinaryWebSocketFrame) { - pojoEndpointServer.doOnBinary(ctx.channel(), frame); - return; - } - if (frame instanceof PongWebSocketFrame) { - return; - } - } - +package org.yeauty.standard; + +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.util.ReferenceCountUtil; + +import org.yeauty.pojo.PojoEndpointServer; + +class WebSocketServerHandler extends SimpleChannelInboundHandler { + + private final PojoEndpointServer pojoEndpointServer; + + ThreadLocal autoReleaseLocal = new ThreadLocal<>(); + + public WebSocketServerHandler(PojoEndpointServer pojoEndpointServer) { + this.pojoEndpointServer = pojoEndpointServer; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception { + handleWebSocketFrame(ctx, msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + pojoEndpointServer.doOnError(ctx.channel(), cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + pojoEndpointServer.doOnClose(ctx.channel()); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + pojoEndpointServer.doOnEvent(ctx.channel(), evt); + } + + /** + * 执行应用程序自定义的接口方法 + * @param ctx + * @param frame + * @return + */ + private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { + autoReleaseLocal.set(true); + if (frame instanceof TextWebSocketFrame) { + pojoEndpointServer.doOnMessage(ctx.channel(), frame); + return ; + } + if (frame instanceof PingWebSocketFrame) { + ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain())); + return ; + } + if (frame instanceof CloseWebSocketFrame) { + ctx.writeAndFlush(frame.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE); + return ; + } + if (frame instanceof BinaryWebSocketFrame) { + boolean bool = pojoEndpointServer.doOnBinary(ctx.channel(), frame); + autoReleaseLocal.set(bool); + return; + } + if (frame instanceof PongWebSocketFrame) { + return ; + } + } + + /** + * 取消自动关闭byteBuf + */ + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + + boolean release = true; + try { + if (acceptInboundMessage(msg)) { + @SuppressWarnings("unchecked") + WebSocketFrame imsg = (WebSocketFrame) msg; + channelRead0(ctx, imsg); + + } else { + release = false; + ctx.fireChannelRead(msg); + + } + } finally { + boolean autoRelease=autoReleaseLocal.get()==null||autoReleaseLocal.get(); + if ( autoRelease&&release) { + ReferenceCountUtil.release(msg); + } + autoReleaseLocal.remove();//防止线程数据脏读 + + } + } + } \ No newline at end of file diff --git a/src/main/java/org/yeauty/support/ByteBufMethodArgumentResolver.java b/src/main/java/org/yeauty/support/ByteBufMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..7703ef216b84563c554c7dd2c8c7f8fedb507c6e --- /dev/null +++ b/src/main/java/org/yeauty/support/ByteBufMethodArgumentResolver.java @@ -0,0 +1,32 @@ +package org.yeauty.support; +import org.springframework.core.MethodParameter; +import org.yeauty.annotation.OnBinary; +import org.yeauty.annotation.OnMessage; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +/** + * 支持使用直接内存,ByteMethodArgumentResolver将堆外内存加载到堆内存了 + * @author giant + * + */ +public class ByteBufMethodArgumentResolver implements MethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + // TODO Auto-generated method stub + return (parameter.getMethod().isAnnotationPresent(OnBinary.class) + ||parameter.getMethod().isAnnotationPresent(OnMessage.class) ) + &¶meter.getParameterType().equals(ByteBuf.class); + } + + @Override + public Object resolveArgument(MethodParameter parameter, Channel channel, Object object) throws Exception { + WebSocketFrame binaryWebSocketFrame = (WebSocketFrame) object;//netty默认是直接内存 + // 解析参数的逻辑 + ByteBuf content = binaryWebSocketFrame.content(); + return content; + } +} \ No newline at end of file diff --git a/src/main/java/org/yeauty/support/ChannelMethodArgumentResolver.java b/src/main/java/org/yeauty/support/ChannelMethodArgumentResolver.java new file mode 100644 index 0000000000000000000000000000000000000000..ff2389eb18ecf876f63082ae827b3762abc4787b --- /dev/null +++ b/src/main/java/org/yeauty/support/ChannelMethodArgumentResolver.java @@ -0,0 +1,29 @@ +package org.yeauty.support; + +import org.springframework.core.MethodParameter; +import org.yeauty.annotation.OnBinary; + +import io.netty.channel.Channel; + +public class ChannelMethodArgumentResolver implements MethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + // TODO Auto-generated method stub + return parameter.getMethod().isAnnotationPresent(OnBinary.class) &¶meter.getParameterType().equals(Channel.class); + } + + @Override + public Object resolveArgument(MethodParameter parameter, Channel channel, Object object) throws Exception { +// BinaryWebSocketFrame binaryWebSocketFrame = (BinaryWebSocketFrame) object;//netty默认是直接内存 +// // 解析参数的逻辑 +//// ByteBuf content = binaryWebSocketFrame.content(); + +// return true; + return channel; + } +} + + + + diff --git a/src/main/java/org/yeauty/util/FileUtils.java b/src/main/java/org/yeauty/util/FileUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e3d7ce8b5579e4586538656951dbc7a0e3bb1257 --- /dev/null +++ b/src/main/java/org/yeauty/util/FileUtils.java @@ -0,0 +1,94 @@ +package org.yeauty.util; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import javax.servlet.http.HttpServletResponse; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ReferenceCountUtil; + +public class FileUtils { + + /** + * 直接内存直接写文件 + * @param byteBuf + * @param filePath + * @throws Exception + */ + public static void writeByteBufToFile(ByteBuf byteBuf, String filePath) throws Exception { + try (RandomAccessFile raf = new RandomAccessFile(filePath, "rw"); + FileChannel fileChannel = raf.getChannel()) { + ByteBuffer buffer = byteBuf.nioBuffer(); +// buffer.allocateDirect(byteBuf.readableBytes());//不需要,byteBuf本来就允许直接内存 +// buffer.flip(); // 切换为读模式 + + // 5. 将数据写入文件(零拷贝) + while (buffer.hasRemaining()) { + fileChannel.write(buffer); + } + + // 6. 强制刷盘 + fileChannel.force(true); + try { + if (fileChannel != null) fileChannel.close(); + if (raf != null) raf.close(); + } catch (IOException e) { + e.printStackTrace(); + } + }finally{ + + byteBuf.release(); + } + } + + + public static void writeByteBufToFile(ByteBuf byteBuf, OutputStream output) throws Exception { + try{ + ByteBuffer buffer = byteBuf.nioBuffer(); +// buffer.allocateDirect(byteBuf.readableBytes());//不需要,byteBuf本来就允许直接内存 +// buffer.flip(); // 切换为读模式 +// byteBuf.readBytes(output, byteBuf.readableBytes()); + WritableByteChannel channel = Channels.newChannel(output); + // 分块写入数据 + while (buffer.hasRemaining()) {//使用通道(如WritableByteChannel)将数据写入HttpServletResponse的输出流。 + channel.write(buffer); + } + }finally{ + byteBuf.release(); + } + + } + + public static void writeByteBufToFile(ByteBuf byteBuf, HttpServletResponse response) throws Exception { +// ByteBufAllocator.DEFAULT.directBuffer(); + try{ + ByteBuffer buffer = byteBuf.nioBuffer(); +// buffer.allocateDirect(byteBuf.readableBytes());//不需要,byteBuf本来就允许直接内存 +// buffer.flip(); // 切换为读模式 + OutputStream outputStream = response.getOutputStream(); + WritableByteChannel channel = Channels.newChannel(outputStream); + // 设置响应头等元数据 + response.setContentType("application/octet-stream"); + response.setContentLength(byteBuf.readableBytes()); + // 分块写入数据 + while (buffer.hasRemaining()) { + channel.write(buffer); + } + + // 确保数据刷出 + outputStream.flush(); + + }finally{ + // 释放ByteBuf资源 + ReferenceCountUtil.safeRelease(byteBuf); + + } + + } +}