Simulating a simple RPC framework with Netty (nettyrpc)
I have been learning Netty recently, so I tried writing this demo. It is not rigorous; it is only a simple simulation.
Dubbo's architecture diagram appeared here (image not reproduced).
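First, a registry center is needed.
The registry center keeps provider information in a Map. Here it simply stores each provider's address. It offers two operations: registering a service and getting the list of addresses that provide a service.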
public class DefaultRegistryCenterServer implements RegistryCenterServer{
private Map<String,Set<SocketAddress>> registryCenter = new ConcurrentHashMap<String, Set<SocketAddress>>();
private EventLoopGroup boss;
private EventLoopGroup worker;
private ServerBootstrap serverBootstrap;
private ChannelHandlerAdapter handler= new DefaultRegisterHandler(this);
private int port;
public DefaultRegistryCenterServer(int port) {
this.port = port;
}
@Override
public void start() {
boss = new NioEventLoopGroup();
worker = new NioEventLoopGroup();
serverBootstrap = new ServerBootstrap()
.group(boss, worker)
.channel(NioServerSocketChannel.class)
.localAddress(new InetSocketAddress(port))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel sChannel) throws Exception {
sChannel.pipeline().addLast("encoder",new StringEncoder());
sChannel.pipeline().addLast(new ObjectDecoder(ClassResolvers.cacheDisabled(this
.getClass().getClassLoader())));
sChannel.pipeline().addLast(handler);
}
})
.option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE,true);
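ChannelFuture future;
try {
future = serverBootstrap.bind(port).sync();
System.out.println("Server start listen at " + port);
future.channel().closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
// release the event loop threads once the server channel has closed
boss.shutdownGracefully();
worker.shutdownGracefully();
}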
}
@Override
public void register(String server,SocketAddress address) {
// computeIfAbsent creates each service's set atomically, replacing the synchronized check-then-put.
// A hash-based concurrent set is used because ConcurrentSkipListSet needs Comparable elements,
// and InetSocketAddress is not Comparable.
registryCenter.computeIfAbsent(server, k -> ConcurrentHashMap.newKeySet()).add(address);
System.out.println(registryCenter);
}
@Override
public Set<SocketAddress> getServers(String service) {
return registryCenter.get(service);
}
}
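As a JDK-only sketch of the registry's two operations (no Netty involved), the class below keeps the same map from service name to a set of provider addresses. It assumes ConcurrentHashMap.newKeySet() as the set type. A ConcurrentSkipListSet created without a comparator casts its elements to Comparable, and InetSocketAddress does not implement it, so adding addresses to it can fail with ClassCastException. The class name InMemoryRegistry is hypothetical.

```java
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public class InMemoryRegistry {
    // service interface name -> addresses of the providers exporting it
    private final Map<String, Set<InetSocketAddress>> providers = new ConcurrentHashMap<>();

    /** Adds a provider address; computeIfAbsent creates the per-service set atomically. */
    public void register(String service, InetSocketAddress address) {
        providers.computeIfAbsent(service, k -> ConcurrentHashMap.newKeySet()).add(address);
    }

    /** Returns a read-only view of the providers, or an empty set for an unknown service. */
    public Set<InetSocketAddress> lookup(String service) {
        Set<InetSocketAddress> set = providers.get(service);
        if (set == null) {
            return Collections.emptySet();
        }
        return Collections.unmodifiableSet(set);
    }

    public static void main(String[] args) {
        InMemoryRegistry registry = new InMemoryRegistry();
        // createUnresolved keeps the demo free of DNS lookups
        registry.register("TestInterface", InetSocketAddress.createUnresolved("127.0.0.1", 8089));
        registry.register("TestInterface", InetSocketAddress.createUnresolved("127.0.0.1", 8090));
        registry.register("TestInterface", InetSocketAddress.createUnresolved("127.0.0.1", 8089)); // duplicate
        System.out.println(registry.lookup("TestInterface").size()); // prints 2
        System.out.println(registry.lookup("Unknown").isEmpty());    // prints true
    }
}
```

Because the value is a set, a provider that registers the same address twice is stored only once. computeIfAbsent also guarantees that two concurrent registrations for a new service end up sharing the same set.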
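Handler:
Depending on the command the client sends, it either registers a service or looks one up.
A lookup replies with the service's address list as a string of the form service//host:port;host:port; followed by the frame delimiter *$*.
The handlers in this article extend ChannelHandlerAdapter and override channelRead, which matches the Netty 5.0 alpha API. On Netty 4.x, extend ChannelInboundHandlerAdapter instead.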
@Sharable
public class DefaultRegisterHandler extends ChannelHandlerAdapter {
private RegistryCenterServer registryCenterServer;
public DefaultRegisterHandler(RegistryCenterServer registryCenterServer) {
this.registryCenterServer = registryCenterServer;
}
@SuppressWarnings("unchecked")
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof ServerCommand) {
byte command = ((ServerCommand) msg).getCommand();
int port = ((ServerCommand) msg).getProviderPort();
if (command == ServerCommand.REG_SERVER) {
// register the provider's IP as seen on this connection; new InetSocketAddress(port) alone is the
// wildcard address 0.0.0.0, which is not a usable destination for consumers on every platform
InetSocketAddress remote = (InetSocketAddress) ctx.channel().remoteAddress();
registryCenterServer.register(((ServerCommand) msg).getServiceInterface(),
new InetSocketAddress(remote.getAddress(), port));
} else if (command == ServerCommand.FIND_SERVER) {
Set<SocketAddress> servers = registryCenterServer
.getServers(((ServerCommand) msg).getServiceInterface());
StringBuilder sb = new StringBuilder();
if (servers != null) {
for (SocketAddress socketAddress : servers) {
// getHostString() avoids the reverse DNS lookup that getHostName() may perform
sb.append(((InetSocketAddress) socketAddress).getHostString()).append(":")
.append(((InetSocketAddress) socketAddress).getPort()).append(";");
}
}
ctx.writeAndFlush(((ServerCommand) msg).getServiceInterface() + "//" + sb.toString() + "*$*");
}
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
}
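The lookup reply is plain text. Below is a small sketch of the same wire format, separated from Netty. LookupReplyFormat and encode() are hypothetical names. The sketch uses getHostString() rather than getHostName(), so building a reply never triggers a reverse DNS lookup.

```java
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;

public class LookupReplyFormat {
    // frame terminator that the client's DelimiterBasedFrameDecoder splits on
    static final String DELIMITER = "*$*";

    /** Builds "service//host:port;host:port;*$*", the format DefaultRegisterHandler writes. */
    static String encode(String service, List<InetSocketAddress> providers) {
        StringBuilder sb = new StringBuilder(service).append("//");
        for (InetSocketAddress a : providers) {
            // getHostString() returns the literal host without a reverse DNS lookup
            sb.append(a.getHostString()).append(':').append(a.getPort()).append(';');
        }
        return sb.append(DELIMITER).toString();
    }

    public static void main(String[] args) {
        List<InetSocketAddress> providers = Arrays.asList(
                InetSocketAddress.createUnresolved("127.0.0.1", 8089),
                InetSocketAddress.createUnresolved("127.0.0.1", 8090));
        System.out.println(encode("TestInterface", providers));
        // prints TestInterface//127.0.0.1:8089;127.0.0.1:8090;*$*
    }
}
```

An IPv6 literal host contains colons, which the client's split(":") would mis-parse. The format implicitly assumes IPv4 addresses or host names.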
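This is the object a client uses to communicate with the registry center.
It carries the command type, the service name and the provider's port.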
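public class ServerCommand implements Serializable {
private static final long serialVersionUID = 1L;
// static fields are never serialized, so "transient" is unnecessary on them
public static final byte REG_SERVER = 1;
public static final byte FIND_SERVER = 2;
private byte command;
private String serviceInterface;
private int providerPort;
private ServerCommand(byte command, String serviceInterface, int providerPort) {
this.command = command;
this.serviceInterface = serviceInterface;
this.providerPort = providerPort;
}
/** Command a provider sends to register serviceInterface at providerPort. */
public static ServerCommand addProvider(String serviceInterface, int providerPort) {
return new ServerCommand(REG_SERVER, serviceInterface, providerPort);
}
/** Command a consumer sends to look up the providers of serviceInterface. */
public static ServerCommand addConsumer(String serviceInterface) {
return new ServerCommand(FIND_SERVER, serviceInterface, 0);
}
public byte getCommand() {
return command;
}
public String getServiceInterface() {
return serviceInterface;
}
public int getProviderPort() {
return providerPort;
}
}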
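ObjectEncoder and ObjectDecoder are built on Java serialization. Netty writes through its own compact stream variant, so the bytes differ from plain ObjectOutputStream output. The sketch below uses plain ObjectOutputStream as an approximation to show what actually travels: the instance fields survive the round trip, while static constants are not part of the serialized form at all. That is why marking them transient has no effect. Command is a hypothetical mirror of ServerCommand.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class CommandRoundTrip {
    /** Mirrors ServerCommand; the static constants are class data and are never written. */
    static final class Command implements Serializable {
        private static final long serialVersionUID = 1L;
        static final byte REG_SERVER = 1;
        static final byte FIND_SERVER = 2;
        final byte command;
        final String serviceInterface;
        final int providerPort;

        Command(byte command, String serviceInterface, int providerPort) {
            this.command = command;
            this.serviceInterface = serviceInterface;
            this.providerPort = providerPort;
        }
    }

    /** Serializes and deserializes an object with plain Java serialization. */
    static Object roundTrip(Serializable obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Command sent = new Command(Command.REG_SERVER, "TestInterface", 8089);
        Command received = (Command) roundTrip(sent);
        System.out.println(received.command + " " + received.serviceInterface + " " + received.providerPort);
        // prints 1 TestInterface 8089
    }
}
```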
Startup class for the registry center:
public class Application {
public static void main(String[] args) {
RegistryCenterServer server = new DefaultRegistryCenterServer(8088);
server.start();
}
}
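Next comes the RPC application itself.
Configuration:
The application's event loop groups, the services it exposes and the services it calls are all kept here.
Exposed services are held in the Provider; services to be called are held in the Consumer.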
public class Configuration {
private EventLoopGroup boss;
private EventLoopGroup worker;
private Consumer consumer;
private Provider provider;
private RegistryCenterClient registryCenterClient;
...
}
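The class that connects to the registry center:
isConnecting() reports whether the connection has been established. Nothing else can happen until it has.
sendCommand() sends commands to the registry center server.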
public class RegistryCenterClient extends Thread{
private Configuration configuration;
private String host;
private int port;
private Bootstrap bootstrap;
volatile private Channel channel;
public RegistryCenterClient(String host, int port,Configuration configuration) {
this.host = host;
this.port = port;
this.configuration = configuration;
}
@Override
public void run() {
bootstrap = new Bootstrap().group(configuration.getWorker()).channel(NioSocketChannel.class).option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel sChannel) throws Exception {
ByteBuf delimiter = Unpooled.copiedBuffer("*$*".getBytes());
sChannel.pipeline().addLast(new DelimiterBasedFrameDecoder(1024 * 2, delimiter))
.addLast("decoder", new StringDecoder()).addLast(new ObjectEncoder())
.addLast(new ClientHandler(configuration.getConsumer()));
}
});
try {
ChannelFuture future = bootstrap.connect(host, port).sync();
this.channel = future.channel();
future.channel().closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
public boolean isConnecting() {
return channel != null;
}
public void sendCommand(ServerCommand... commands) {
for (ServerCommand serverCommand : commands) {
channel.write(serverCommand);
}
channel.flush();
}
}
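RpcController waits for this client with an empty while (!isConnecting()) loop. That loop keeps a core busy and never returns if the connection fails. One alternative, sketched here without Netty, is a CountDownLatch: the connecting thread counts it down once the channel is set, and the caller waits on it with a timeout. Thread.sleep() stands in for bootstrap.connect(host, port).sync(). ConnectWait, Client, awaitConnected() and channel() are hypothetical names.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class ConnectWait {
    /** Hypothetical stand-in for RegistryCenterClient: connects on its own thread. */
    static class Client extends Thread {
        private final CountDownLatch connected = new CountDownLatch(1);
        private final long simulatedConnectMillis;
        private Object channel; // stands in for io.netty.channel.Channel

        Client(long simulatedConnectMillis) {
            this.simulatedConnectMillis = simulatedConnectMillis;
            setDaemon(true); // do not keep the JVM alive in this demo
        }

        @Override
        public void run() {
            try {
                Thread.sleep(simulatedConnectMillis); // real client: bootstrap.connect(host, port).sync()
                channel = new Object();
                // countDown() publishes the write above to every thread returning from await()
                connected.countDown();
                // real client: channel.closeFuture().sync() would keep this thread alive here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        /** Waits for the connection instead of spinning; returns false on timeout. */
        boolean awaitConnected(long timeout, TimeUnit unit) throws InterruptedException {
            return connected.await(timeout, unit);
        }

        Object channel() {
            return channel;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Client fast = new Client(50);
        fast.start();
        System.out.println(fast.awaitConnected(1, TimeUnit.SECONDS)); // prints true
        Client slow = new Client(5_000);
        slow.start();
        System.out.println(slow.awaitConnected(100, TimeUnit.MILLISECONDS)); // prints false
    }
}
```

The latch also makes the volatile field unnecessary. Actions before countDown() happen-before a successful return from await(), so the waiting thread sees the assigned channel.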
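Handler:
After a lookup command has been sent, this handler receives the address list returned by the server. It splits the string and stores the addresses in the Consumer.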
public class ClientHandler extends ChannelHandlerAdapter {
private Consumer consumer;
public ClientHandler(Consumer consumer) {
this.consumer = consumer;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
System.out.println("register message: " + msg);
String command = msg.toString();
String[] strs = command.split("//");
if(strs.length > 1) {
String[] addressStr = strs[1].split(";");
InetSocketAddress[] address = new InetSocketAddress[addressStr.length];
for (int i = 0; i < address.length; i++) {
System.out.println(addressStr[i]);
String[] data = addressStr[i].split(":");
address[i] = new InetSocketAddress(data[0], Integer.parseInt(data[1]));
}
consumer.setAddressList(strs[0], address);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
}
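TCP does not preserve message boundaries. One registry reply can arrive in several reads, or two replies can arrive in one read. That is why the pipeline puts a DelimiterBasedFrameDecoder on "*$*" in front of the StringDecoder. The sketch below imitates that framing step and ClientHandler's parsing in plain Java. It is an illustration, not Netty's implementation, and ReplyParser, feed(), serviceName() and parseAddresses() are hypothetical names.

```java
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;

public class ReplyParser {
    static final String DELIMITER = "*$*";
    // data received so far that does not yet form a complete frame
    private final StringBuilder buffer = new StringBuilder();

    /** Appends a chunk as it came off the socket and returns every complete frame, delimiter removed. */
    List<String> feed(String chunk) {
        buffer.append(chunk);
        List<String> frames = new ArrayList<>();
        int end;
        while ((end = buffer.indexOf(DELIMITER)) >= 0) {
            frames.add(buffer.substring(0, end));
            buffer.delete(0, end + DELIMITER.length());
        }
        return frames;
    }

    /** Service name: the part before "//". */
    static String serviceName(String frame) {
        return frame.split("//")[0];
    }

    /** Parses "service//host:port;host:port;" into addresses, splitting host and port at the last colon. */
    static List<InetSocketAddress> parseAddresses(String frame) {
        List<InetSocketAddress> result = new ArrayList<>();
        String[] parts = frame.split("//");
        if (parts.length < 2) {
            return result; // "service//" means no provider is registered
        }
        for (String entry : parts[1].split(";")) {
            if (entry.isEmpty()) {
                continue;
            }
            int colon = entry.lastIndexOf(':');
            if (colon < 0) {
                throw new IllegalArgumentException("malformed address: " + entry);
            }
            String host = entry.substring(0, colon);
            int port = Integer.parseInt(entry.substring(colon + 1));
            // createUnresolved skips DNS resolution in this sketch
            result.add(InetSocketAddress.createUnresolved(host, port));
        }
        return result;
    }

    public static void main(String[] args) {
        ReplyParser parser = new ReplyParser();
        // the first reply is split across two reads, and the second read also carries another reply
        System.out.println(parser.feed("TestInterface//127.0.0.1:8089;127.0.")); // prints []
        for (String frame : parser.feed("0.1:8090;*$*Other//*$*")) {
            System.out.println(serviceName(frame) + " -> " + parseAddresses(frame).size());
        }
        // prints TestInterface -> 2, then Other -> 0
    }
}
```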
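After getting the service's address list from the registry center, the consumer connects directly to a provider and talks to it using these two objects: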
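public class Req_Message implements Serializable {
private long id; // request id, used to match the response to the waiting caller
private String serviceInterface;
private String methodName;
private Object[] param;
// constructor Req_Message(long, String, String, Object[]) and getters omitted
}
public class Res_Message implements Serializable {
private Long id; // id of the request this response answers
private String msg;
private Object result;
// constructor Res_Message(Long, String, Object) and getters omitted
}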
RpcController:
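This is the class users call to drive the RPC framework.
To expose services: call createProvider(port) to get a Provider object, then add the services to expose to it.
To call services: call createConsumer() to get a Consumer object, then add the services to be called to it.
start() launches everything: initialize the configuration -> connect to the registry center -> start the containers.
Once started, create a service object with newServiceProxy(interface).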
public class RpcController {
private Configuration configuration;
private Invoker invoker;
private Exporter exporter;
public RpcController(String registerHost, int registerPort) {
configuration = new Configuration();
configuration.setRegistryCenterClient(new RegistryCenterClient(registerHost, registerPort,configuration));
}
public void start() {
try {
initEventLoopGroup();
connectRegistryCenter();
startContainer();
System.out.println("startup");
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
private void startContainer() throws ClassNotFoundException {
if(configuration.getProvider() != null) {
exporter = new Exporter(configuration);
exporter.start();
Provider provider = configuration.getProvider();
provider.registerService();
}
if(configuration.getConsumer() != null) {
invoker = new Invoker(configuration);
invoker.start();
Consumer consumer = configuration.getConsumer();
consumer.findService();
}
}
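private void connectRegistryCenter() {
RegistryCenterClient client = configuration.getRegistryCenterClient();
client.start();
// Poll with a short sleep instead of an empty busy loop, and give up after 5 s (an arbitrary
// limit) rather than hanging forever if the registry center is unreachable.
long deadline = System.currentTimeMillis() + 5000;
while (!client.isConnecting()) {
if (System.currentTimeMillis() > deadline) {
throw new IllegalStateException("cannot connect to the registry center");
}
try {
Thread.sleep(10);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException("interrupted while connecting to the registry center", e);
}
}
}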
private void initEventLoopGroup() {
configuration.setWorker(new NioEventLoopGroup());
if(configuration.getProvider() != null) {
configuration.setBoss(new NioEventLoopGroup());
}
}
public Consumer createConsumer() {
configuration.setConsumer(new Consumer(configuration.getRegistryCenterClient()));
return configuration.getConsumer();
}
public Provider createProvider(int port) {
configuration.setProvider(new Provider(configuration.getRegistryCenterClient(),port));
return configuration.getProvider();
}
public <T> T newServiceProxy(Class<T> serviceInterface){
return invoker.newServiceProxy(serviceInterface);
}
}
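If a Provider object has been created, the Exporter is started. The Exporter opens a server channel to serve requests.
It receives the interface name, method name and arguments sent by the caller, and returns the result of the call.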
public class Exporter extends Thread{
private Configuration configuration;
private ServerBootstrap serverBootstrap;
private ChannelHandler handler;
public Exporter(Configuration configuration) {
this.configuration = configuration;
this.handler = new ExporterHandler(configuration);
}
@Override
public void run() {
serverBootstrap = new ServerBootstrap().group(configuration.getBoss(), configuration.getWorker())
.channel(NioServerSocketChannel.class)
.localAddress(new InetSocketAddress(configuration.getProvider().getPort()))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel sChannel) throws Exception {
sChannel.pipeline().addLast(new ObjectEncoder())
.addLast(new ObjectDecoder(
ClassResolvers.cacheDisabled(this.getClass().getClassLoader())))
.addLast(handler);
}
}).option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE, true);
ChannelFuture future;
try {
future = serverBootstrap.bind(configuration.getProvider().getPort()).sync();
future.channel().closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
handler:
@Sharable
public class ExporterHandler extends ChannelHandlerAdapter {
private Configuration configuration;
public ExporterHandler(Configuration configuration) {
this.configuration = configuration;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof Req_Message) {
long id = ((Req_Message) msg).getId();
String serviceInterface = ((Req_Message) msg).getServiceInterface();
String methodName = ((Req_Message) msg).getMethodName();
Object[] param = ((Req_Message) msg).getParam();
Provider provider = configuration.getProvider();
Object service = provider.getService(serviceInterface);
Object result = null;
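if (param != null && param.length > 0) {
// Looking the method up by the arguments' runtime classes only works when the declared parameter
// types are exactly those classes (Integer rather than int) and no argument is null.
Class<?>[] pts = new Class<?>[param.length];
for (int i = 0; i < pts.length; i++) {
pts[i] = param[i].getClass();
}
result = service.getClass().getMethod(methodName, pts).invoke(service, param);
} else {
// no-argument method: call getMethod without passing a null varargs array
result = service.getClass().getMethod(methodName).invoke(service);
}
Res_Message res_Message = new Res_Message(id, "", result);
ctx.writeAndFlush(res_Message);
} else {
// only messages this handler does not handle are passed on down the pipeline
super.channelRead(ctx, msg);
}
}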
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
}
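ExporterHandler finds the target method with getMethod(methodName, argumentClasses). That only works when each declared parameter type is exactly the runtime class of the argument. A method declared as sum(int, int) is not found for Integer arguments, and a null argument causes a NullPointerException. The test interface uses Integer, which sidesteps this. The sketch below, a hypothetical ReflectiveDispatch helper, scans the public methods instead. It looks for one whose name, arity and parameter types accept the arguments, and it counts a primitive parameter as accepting its wrapper type. Another common option is to send the parameter types in the request message.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class ReflectiveDispatch {
    private static final Map<Class<?>, Class<?>> WRAPPERS = new HashMap<>();
    static {
        WRAPPERS.put(boolean.class, Boolean.class);
        WRAPPERS.put(byte.class, Byte.class);
        WRAPPERS.put(char.class, Character.class);
        WRAPPERS.put(short.class, Short.class);
        WRAPPERS.put(int.class, Integer.class);
        WRAPPERS.put(long.class, Long.class);
        WRAPPERS.put(float.class, Float.class);
        WRAPPERS.put(double.class, Double.class);
    }

    /** True if arg can be passed to a parameter of paramType (unboxing allowed, null only for references). */
    static boolean accepts(Class<?> paramType, Object arg) {
        if (arg == null) {
            return !paramType.isPrimitive();
        }
        Class<?> boxed = paramType.isPrimitive() ? WRAPPERS.get(paramType) : paramType;
        return boxed.isInstance(arg);
    }

    /** Invokes the first public method of service whose name and parameters accept the arguments. */
    static Object invoke(Object service, String methodName, Object[] args) throws Exception {
        Object[] actual = args == null ? new Object[0] : args;
        for (Method m : service.getClass().getMethods()) {
            if (!m.getName().equals(methodName) || m.getParameterCount() != actual.length) {
                continue;
            }
            Class<?>[] types = m.getParameterTypes();
            boolean match = true;
            for (int i = 0; i < types.length && match; i++) {
                match = accepts(types[i], actual[i]);
            }
            if (match) {
                return m.invoke(service, actual); // Method.invoke unboxes wrapper arguments
            }
        }
        throw new NoSuchMethodException(service.getClass().getName() + "." + methodName);
    }

    public interface Calculator {
        int sum(int a, int b);
        String echo(String s);
    }

    public static class CalculatorImpl implements Calculator {
        public int sum(int a, int b) { return a + b; }
        public String echo(String s) { return "echo:" + s; }
    }

    public static void main(String[] args) throws Exception {
        Object service = new CalculatorImpl();
        // getMethod("sum", Integer.class, Integer.class) would throw here, because sum takes int
        System.out.println(invoke(service, "sum", new Object[] {2, 1}));  // prints 3
        System.out.println(invoke(service, "echo", new Object[] {null})); // prints echo:null
    }
}
```

If several overloads accept the same arguments, the sketch takes the first one returned by getMethods(), whose order is unspecified. Sending parameter types with the request avoids that ambiguity.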
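If a Consumer object has been created, the Invoker is started. When a service proxy is created, the Invoker looks up that service's address list in the Consumer, picks one address at random and connects to it.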
public class Invoker extends Thread{
private Configuration configuration;
private final Bootstrap bootstrap;
private ServiceProxy serviceProxy;
private ChannelHandler handler;
public Invoker(Configuration configuration) {
this.configuration = configuration;
this.serviceProxy = new ServiceProxy(configuration);
this.handler = new InvokerHandler(configuration);
// The Bootstrap is built here instead of in run(). RpcController calls newServiceProxy() right after
// invoker.start(), when run() may not have assigned the field yet. run() is therefore not overridden;
// the class still extends Thread only so that RpcController's invoker.start() call compiles.
this.bootstrap = new Bootstrap().group(configuration.getWorker()).channel(NioSocketChannel.class)
.option(ChannelOption.TCP_NODELAY, true).handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel sChannel) throws Exception {
sChannel.pipeline().addLast(new ObjectEncoder())
.addLast(new ObjectDecoder(
ClassResolvers.cacheDisabled(this.getClass().getClassLoader())))
.addLast(handler);
}
});
}
public <T> T newServiceProxy(Class<T> serviceInterface) {
Consumer consumer = configuration.getConsumer();
Channel channel = consumer.getChannel(serviceInterface.getName());
if (channel == null) {
// waits until the registry center has answered the lookup (Consumer.getAddress polls for it)
InetSocketAddress address = consumer.getAddress(serviceInterface.getName());
System.out.println("service address:" + address);
try {
ChannelFuture channelFuture = bootstrap.connect(address).sync();
channel = channelFuture.channel();
consumer.setChannel(serviceInterface.getName(), channel);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Object proxy = serviceProxy.newProxy(serviceInterface);
// Class.cast gives a checked conversion instead of an unchecked (T) cast
return serviceInterface.cast(proxy);
}
}
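Consumer.getAddress() picks a provider at random. Below, that choice is sketched as a standalone function next to a round-robin alternative that cycles through the providers. ProviderSelection and RoundRobin are hypothetical names, and ThreadLocalRandom replaces the per-call new Random().

```java
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;

public class ProviderSelection {
    /** Random choice, as Consumer.getAddress() does. */
    static InetSocketAddress random(List<InetSocketAddress> providers) {
        if (providers.isEmpty()) {
            throw new IllegalStateException("no provider available");
        }
        return providers.get(ThreadLocalRandom.current().nextInt(providers.size()));
    }

    /** Round-robin alternative: each call takes the next provider in turn. */
    static final class RoundRobin {
        private final AtomicInteger next = new AtomicInteger();

        InetSocketAddress select(List<InetSocketAddress> providers) {
            if (providers.isEmpty()) {
                throw new IllegalStateException("no provider available");
            }
            // floorMod keeps the index non-negative after the counter overflows
            return providers.get(Math.floorMod(next.getAndIncrement(), providers.size()));
        }
    }

    public static void main(String[] args) {
        List<InetSocketAddress> providers = Arrays.asList(
                InetSocketAddress.createUnresolved("127.0.0.1", 8089),
                InetSocketAddress.createUnresolved("127.0.0.1", 8090));
        System.out.println(random(providers).getPort()); // 8089 or 8090
        RoundRobin rr = new RoundRobin();
        System.out.println(rr.select(providers).getPort()); // prints 8089
        System.out.println(rr.select(providers).getPort()); // prints 8090
    }
}
```

Random selection needs no shared state. Round-robin spreads calls more evenly but needs a shared counter, here an AtomicInteger.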
Handler:
When a message arrives, it calls setResult() on the Consumer, which wakes up the waiting thread.
@Sharable
public class InvokerHandler extends ChannelHandlerAdapter {
private Configuration configuration;
public InvokerHandler(Configuration configuration) {
this.configuration = configuration;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if(msg instanceof Res_Message) {
Long id = ((Res_Message) msg).getId();
Object result = ((Res_Message) msg).getResult();
configuration.getConsumer().setResult(id, result);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}
}
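The article wakes the caller with a CountDownLatch stored per request id. A different technique, swapped in here only for comparison, keeps one CompletableFuture per id: the I/O thread completes it and the caller blocks on get() with a timeout. The future is registered before the request is sent, so a reply that arrives early is not lost. The entry is removed on completion or on timeout. PendingRequests, Call, newCall(), complete(), await() and pendingCount() are hypothetical names; no Netty is involved.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;

public class PendingRequests {
    /** One outstanding call: its unique id and the future its reply will complete. */
    static final class Call {
        final long id;
        final CompletableFuture<Object> future = new CompletableFuture<>();
        Call(long id) { this.id = id; }
    }

    private final AtomicLong ids = new AtomicLong();
    private final Map<Long, CompletableFuture<Object>> pending = new ConcurrentHashMap<>();

    /** Reserves an id and registers its future; call this before the request is written. */
    Call newCall() {
        Call call = new Call(ids.incrementAndGet());
        pending.put(call.id, call.future);
        return call;
    }

    /** Called by the I/O thread when a reply arrives; replies for unknown ids are ignored. */
    void complete(long id, Object result) {
        CompletableFuture<Object> future = pending.remove(id);
        if (future != null) {
            future.complete(result);
        }
    }

    /** Blocks until the reply arrives or the timeout elapses, and always forgets the id afterwards. */
    Object await(Call call, long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        try {
            return call.future.get(timeout, unit);
        } finally {
            pending.remove(call.id); // no-op if complete() already removed it
        }
    }

    int pendingCount() {
        return pending.size();
    }

    public static void main(String[] args) throws Exception {
        PendingRequests requests = new PendingRequests();
        Call call = requests.newCall();
        // simulate the I/O thread delivering the reply from another thread
        new Thread(() -> requests.complete(call.id, 3)).start();
        System.out.println(requests.await(call, 1, TimeUnit.SECONDS)); // prints 3
        try {
            requests.await(requests.newCall(), 100, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            System.out.println("timed out, pending = " + requests.pendingCount()); // prints timed out, pending = 0
        }
    }
}
```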
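When a method is called on the proxy object, a request is sent to the provider. A Response object for that request is kept in a Map, and the calling thread waits in WaitingResponse(). When channelRead() receives the provider's result, it wakes the thread up again, and the result is returned.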
public class ServiceProxy implements InvocationHandler{
private Configuration configuration;
public ServiceProxy(Configuration configuration) {
this.configuration = configuration;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
Consumer consumer = configuration.getConsumer();
String serviceName = proxy.getClass().getInterfaces()[0].getName();
Channel channel = consumer.getChannel(serviceName);
Long id = consumer.getRequestId();
Req_Message msg = new Req_Message(id, serviceName, method.getName(), args);
channel.write(msg);
channel.flush();
consumer.WaitingResponse(id);
Object result = consumer.getRequest(id);
consumer.removeResponse(id);
return result;
}
public Object newProxy(Class serviceInterface) {
return Proxy.newProxyInstance(ServiceProxy.class.getClassLoader(), new Class[] {serviceInterface}, this);
}
}
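ServiceProxy relies on JDK dynamic proxies. Every call on the interface lands in invoke(), which turns the interface name, method name and arguments into a request. The sketch below shows that mechanism in isolation. A hypothetical Transport interface takes the place of "write a Req_Message and wait for the Res_Message", so everything runs in one JVM. As a design choice of this sketch only, toString(), hashCode() and equals() are answered locally. In ServiceProxy they would be sent to the provider like any other method.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public class ProxyDemo {
    /** Hypothetical stand-in for "send a Req_Message and wait for the Res_Message". */
    interface Transport {
        Object call(String serviceInterface, String methodName, Object[] args) throws Exception;
    }

    public interface TestInterface {
        void hello();
        Integer sum(Integer a, Integer b);
    }

    /** Creates a proxy that forwards every interface method to the transport. */
    static <T> T newProxy(Class<T> serviceInterface, Transport transport) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getDeclaringClass() == Object.class) {
                // answered locally, so printing or hashing a proxy needs no network call
                switch (method.getName()) {
                    case "equals":
                        return proxy == args[0];
                    case "hashCode":
                        return System.identityHashCode(proxy);
                    default:
                        return "RpcProxy[" + serviceInterface.getName() + "]";
                }
            }
            return transport.call(serviceInterface.getName(), method.getName(), args);
        };
        Object instance = Proxy.newProxyInstance(serviceInterface.getClassLoader(),
                new Class<?>[] {serviceInterface}, handler);
        return serviceInterface.cast(instance);
    }

    public static void main(String[] args) {
        // a local "provider" standing in for the remote TestImpl
        Transport local = (service, methodName, callArgs) -> {
            switch (methodName) {
                case "hello":
                    System.out.println("hello rpc");
                    return null;
                case "sum":
                    return (Integer) callArgs[0] + (Integer) callArgs[1];
                default:
                    throw new UnsupportedOperationException(methodName);
            }
        };
        TestInterface proxy = newProxy(TestInterface.class, local);
        proxy.hello();                                // prints hello rpc
        System.out.println("2+1=" + proxy.sum(2, 1)); // prints 2+1=3
    }
}
```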
Consumer and Provider:
public class Consumer {
private RegistryCenterClient client;
private AtomicLong requestId;
Map<Long, Response> responses;
private Map<String, Service> services;
{
requestId = new AtomicLong();
services = new HashMap<>();
responses = new ConcurrentHashMap<>();
}
protected Consumer(RegistryCenterClient client) {
this.client = client;
}
public void findService() throws ClassNotFoundException {
Set<Entry<String, Service>> entrySet = services.entrySet();
for (Entry<String, Service> entry : entrySet) {
String serviceName = entry.getKey();
client.sendCommand(ServerCommand.addConsumer(serviceName));
}
}
public void addService(Class... serviceInterface) {
for (Class service : serviceInterface) {
services.put(service.getName(), new Service(service));
}
}
public boolean containsService(String serviceName) {
return services.containsKey(serviceName);
}
public void setAddressList(String serviceName, InetSocketAddress... address) {
Service service = services.get(serviceName);
if (service != null) {
service.addressList = address;
}
}
public Channel getChannel(String serviceName) {
return services.get(serviceName).channel;
}
public Channel setChannel(String serviceName, Channel channel) {
return services.get(serviceName).channel = channel;
}
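public InetSocketAddress getAddress(String service) {
InetSocketAddress[] addressList = services.get(service).addressList;
long start = System.currentTimeMillis();
// the registry center's reply arrives asynchronously, so poll for up to 3 seconds
while (addressList == null && System.currentTimeMillis() - start <= 3000) {
try {
TimeUnit.MILLISECONDS.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
addressList = services.get(service).addressList;
}
if (addressList == null || addressList.length == 0) {
// fail with a clear message instead of a NullPointerException on addressList.length
throw new IllegalStateException("no provider found for " + service);
}
// pick a provider at random; ThreadLocalRandom avoids creating a new Random per call
return addressList[ThreadLocalRandom.current().nextInt(addressList.length)];
}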
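public Long getRequestId() {
long id = requestId.incrementAndGet();
// Register the response slot now, before ServiceProxy writes the request, so a reply that
// arrives before WaitingResponse() is called is not dropped by setResult().
responses.put(id, new Response());
return id;
}
public void WaitingResponse(long id) throws InterruptedException {
Response response = responses.get(id);
if (response != null) {
// returns once setResult() counts down, or after 1 s; on timeout the result stays null
response.latch.await(1, TimeUnit.SECONDS);
}
}
public void setResult(long id, Object result) {
Response response = responses.get(id);
if (response != null) { // replies for unknown or already removed ids are ignored
response.result = result;
response.latch.countDown();
}
}
public void removeResponse(long id) {
responses.remove(id);
}
public Object getRequest(long id) {
Response response = responses.get(id);
return response == null ? null : response.result;
}
private class Service {
Class<?> service;
// written by the Netty I/O thread and read by the calling thread, hence volatile
volatile InetSocketAddress[] addressList;
InetSocketAddress curAddress;
volatile Channel channel;
ServiceProxy proxy;
public Service(Class<?> service) {
this.service = service;
}
}
private static class Response {
// created together with the Response, so setResult() can never see a null latch
final CountDownLatch latch = new CountDownLatch(1);
volatile Object result;
}
}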
In Consumer, WaitingResponse() uses a CountDownLatch to make the thread wait and to wake it up again. Request ids come from an AtomicLong, so each id is unique within the Consumer.
public class Provider {
private RegistryCenterClient client;
private int port;
private Map<String,Object> services;
protected Provider(RegistryCenterClient client,int port) {
this.client = client;
this.port = port;
services = new HashMap<>();
}
public void registerService() throws ClassNotFoundException {
Set<Entry<String, Object>> entrySet = services.entrySet();
for (Entry<String, Object> entry : entrySet) {
String serviceName = entry.getKey();
client.sendCommand(ServerCommand.addProvider(serviceName, port));
}
}
public <T> void addService(Class<T> serviceInterface, T serviceImpl) {
// the type parameter lets the compiler check that serviceImpl implements serviceInterface
services.put(serviceInterface.getName(), serviceImpl);
}
public Object getService(String serviceInterface) {
return services.get(serviceInterface);
}
public int getPort() {
return port;
}
}
A simple test:
Service interface:
public interface TestInterface {
void hello();
Integer sum(Integer a,Integer b);
}
Implementation:
public class TestImpl implements TestInterface{
@Override
public void hello() {
System.out.println("hello rpc");
}
@Override
public Integer sum(Integer a, Integer b) {
return a+b;
}
}
Start the registry center first (the Application class above).
Provider:
@Test
public void provider1() throws IOException {
RpcController rpc = new RpcController("127.0.0.1", 8088);
Provider createProvider = rpc.createProvider(8089);
createProvider.addService(TestInterface.class, new TestImpl());
rpc.start();
System.out.println("provider1 start");
System.in.read();
}
Consumer:
@Test
public void consumer1() throws IOException {
RpcController rpc = new RpcController("127.0.0.1", 8088);
Consumer createConsumer = rpc.createConsumer();
createConsumer.addService(TestInterface.class);
rpc.start();
TestInterface newServiceProxy = rpc.newServiceProxy(TestInterface.class);
newServiceProxy.hello();
System.out.println("2+1=" + newServiceProxy.sum(2, 1));
System.out.println("2+9=" +newServiceProxy.sum(2, 9));
System.in.read();
}
Results:
Two providers and two consumers were started (the output screenshot is not reproduced here).