001    // =================================================================================================
002    // Copyright 2011 Twitter, Inc.
003    // -------------------------------------------------------------------------------------------------
004    // Licensed under the Apache License, Version 2.0 (the "License");
005    // you may not use this work except in compliance with the License.
006    // You may obtain a copy of the License in the LICENSE file, or at:
007    //
008    //  http://www.apache.org/licenses/LICENSE-2.0
009    //
010    // Unless required by applicable law or agreed to in writing, software
011    // distributed under the License is distributed on an "AS IS" BASIS,
012    // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013    // See the License for the specific language governing permissions and
014    // limitations under the License.
015    // =================================================================================================
016    
017    package com.twitter.common.zookeeper;
018    
019    import java.io.ByteArrayInputStream;
020    import java.io.ByteArrayOutputStream;
021    import java.io.IOException;
022    import java.net.InetSocketAddress;
023    import java.util.Map;
024    import java.util.Set;
025    import java.util.logging.Level;
026    import java.util.logging.Logger;
027    
028    import javax.annotation.Nullable;
029    
030    import com.google.common.annotations.VisibleForTesting;
031    import com.google.common.base.Joiner;
032    import com.google.common.base.Preconditions;
033    import com.google.common.base.Predicate;
034    import com.google.common.base.Predicates;
035    import com.google.common.base.Throwables;
036    import com.google.common.cache.CacheBuilder;
037    import com.google.common.cache.CacheLoader;
038    import com.google.common.cache.LoadingCache;
039    import com.google.common.collect.ImmutableList;
040    import com.google.common.collect.ImmutableSet;
041    import com.google.common.collect.ImmutableSortedSet;
042    import com.google.common.collect.Iterables;
043    import com.google.common.collect.MapDifference;
044    import com.google.common.collect.MapDifference.ValueDifference;
045    import com.google.common.collect.Maps;
046    import com.google.common.collect.Sets;
047    import com.google.common.collect.Sets.SetView;
048    import com.google.common.util.concurrent.UncheckedExecutionException;
049    import com.google.gson.GsonBuilder;
050    
051    import org.apache.zookeeper.KeeperException;
052    import org.apache.zookeeper.KeeperException.NoNodeException;
053    import org.apache.zookeeper.WatchedEvent;
054    import org.apache.zookeeper.Watcher;
055    import org.apache.zookeeper.Watcher.Event.KeeperState;
056    import org.apache.zookeeper.ZooDefs;
057    import org.apache.zookeeper.data.ACL;
058    
059    import com.twitter.common.args.Arg;
060    import com.twitter.common.args.CmdLine;
061    import com.twitter.common.base.Command;
062    import com.twitter.common.base.Function;
063    import com.twitter.common.base.Supplier;
064    import com.twitter.common.io.Codec;
065    import com.twitter.common.io.CompatibilityCodec;
066    import com.twitter.common.io.JsonCodec;
067    import com.twitter.common.io.ThriftCodec;
068    import com.twitter.common.util.BackoffHelper;
069    import com.twitter.common.zookeeper.Group.GroupChangeListener;
070    import com.twitter.common.zookeeper.Group.JoinException;
071    import com.twitter.common.zookeeper.Group.Membership;
072    import com.twitter.common.zookeeper.Group.WatchException;
073    import com.twitter.common.zookeeper.ZooKeeperClient.ZooKeeperConnectionException;
074    import com.twitter.thrift.Endpoint;
075    import com.twitter.thrift.ServiceInstance;
076    import com.twitter.thrift.Status;
077    
078    /**
079     * Implementation of {@link ServerSet}.
080     *
081     * @author John Sirois
082     */
083    public class ServerSetImpl implements ServerSet {
084      private static final Logger LOG = Logger.getLogger(ServerSetImpl.class.getName());
085    
086      @CmdLine(name = "serverset_encode_json",
087               help = "If true, use JSON for encoding server set information. Defaults to false (use Thrift).")
088      private static final Arg<Boolean> ENCODE_JSON = Arg.create(false);
089    
090      private final ZooKeeperClient zkClient;
091      private final Group group;
092      private final Codec<ServiceInstance> codec;
093      private final BackoffHelper backoffHelper;
094    
095      /**
096       * Creates a new ServerSet using open ZooKeeper node ACLs.
097       *
098       * @param zkClient the client to use for interactions with ZooKeeper
099       * @param path the name-service path of the service to connect to
100       */
101      public ServerSetImpl(ZooKeeperClient zkClient, String path) {
102        this(zkClient, ZooDefs.Ids.OPEN_ACL_UNSAFE, path);
103      }
104    
105      /**
106       * Creates a new ServerSet for the given service {@code path}.
107       *
108       * @param zkClient the client to use for interactions with ZooKeeper
109       * @param acl the ACL to use for creating the persistent group path if it does not already exist
110       * @param path the name-service path of the service to connect to
111       */
112      public ServerSetImpl(ZooKeeperClient zkClient, Iterable<ACL> acl, String path) {
113        this(zkClient, new Group(zkClient, acl, path), createDefaultCodec());
114      }
115    
116      /**
117       * Creates a new ServerSet using the given service {@code group}.
118       *
119       * @param zkClient the client to use for interactions with ZooKeeper
120       * @param group the server group
121       */
122      public ServerSetImpl(ZooKeeperClient zkClient, Group group) {
123        this(zkClient, group, createDefaultCodec());
124      }
125    
126      /**
127       * Creates a new ServerSet using the given service {@code group} and a custom {@code codec}.
128       *
129       * @param zkClient the client to use for interactions with ZooKeeper
130       * @param group the server group
131       * @param codec a codec to use for serializing and de-serializing the ServiceInstance data to and
132       *     from a byte array
133       */
134      public ServerSetImpl(ZooKeeperClient zkClient, Group group, Codec<ServiceInstance> codec) {
135        this.zkClient = Preconditions.checkNotNull(zkClient);
136        this.group = Preconditions.checkNotNull(group);
137        this.codec = Preconditions.checkNotNull(codec);
138    
139        // TODO(John Sirois): Inject the helper so that backoff strategy can be configurable.
140        backoffHelper = new BackoffHelper();
141      }
142    
143      @VisibleForTesting
144      ZooKeeperClient getZkClient() {
145        return zkClient;
146      }
147    
148      @Override
149      public EndpointStatus join(InetSocketAddress endpoint,
150          Map<String, InetSocketAddress> additionalEndpoints, Status status)
151          throws JoinException, InterruptedException {
152        Preconditions.checkNotNull(endpoint);
153        Preconditions.checkNotNull(additionalEndpoints);
154        Preconditions.checkNotNull(status);
155    
156        final MemberStatus memberStatus = new MemberStatus(endpoint, additionalEndpoints, status);
157        Supplier<byte[]> serviceInstanceSupplier = new Supplier<byte[]>() {
158          @Override public byte[] get() {
159            return memberStatus.serializeServiceInstance();
160          }
161        };
162        final Membership membership = group.join(serviceInstanceSupplier);
163    
164        return new EndpointStatus() {
165          @Override public void update(Status status) throws UpdateException {
166            Preconditions.checkNotNull(status);
167            memberStatus.updateStatus(membership, status);
168          }
169        };
170      }
171    
172      @Override
173      public void monitor(final HostChangeMonitor<ServiceInstance> monitor) throws MonitorException {
174        ServerSetWatcher serverSetWatcher = new ServerSetWatcher(zkClient, monitor);
175        try {
176          serverSetWatcher.watch();
177        } catch (WatchException e) {
178          throw new MonitorException("ZooKeeper watch failed.", e);
179        } catch (InterruptedException e) {
180          throw new MonitorException("Interrupted while watching ZooKeeper.", e);
181        }
182      }
183    
184      private class MemberStatus {
185        private final InetSocketAddress endpoint;
186        private final Map<String, InetSocketAddress> additionalEndpoints;
187        private volatile Status status;
188    
189        private MemberStatus(InetSocketAddress endpoint,
190            Map<String, InetSocketAddress> additionalEndpoints, Status status) {
191    
192          this.endpoint = endpoint;
193          this.additionalEndpoints = additionalEndpoints;
194          this.status = status;
195        }
196    
197        synchronized void updateStatus(Membership membership, Status status) throws UpdateException {
198          if (this.status != status) {
199            this.status = status;
200            if (Status.DEAD == status) {
201              try {
202                membership.cancel();
203              } catch (JoinException e) {
204                throw new UpdateException(
205                    "Failed to auto-cancel group membership on transition to DEAD status", e);
206              }
207            } else {
208              try {
209                membership.updateMemberData();
210              } catch (Group.UpdateException e) {
211                throw new UpdateException(
212                    "Failed to update service data for: " + membership.getMemberPath(), e);
213              }
214            }
215          }
216        }
217    
218        byte[] serializeServiceInstance() {
219          ServiceInstance serviceInstance =
220              new ServiceInstance(toEndpoint(endpoint),
221                  Maps.transformValues(additionalEndpoints, TO_ENDPOINT), status);
222          LOG.info("updating endpoint data to:\n\t" + serviceInstance);
223          ByteArrayOutputStream output = new ByteArrayOutputStream();
224          try {
225            codec.serialize(serviceInstance, output);
226          } catch (IOException e) {
227            throw new IllegalStateException("Unexpected problem serializing thrift struct: " +
228                                            serviceInstance + " to a byte[]", e);
229          }
230          return output.toByteArray();
231        }
232      }
233    
234      private static final Function<InetSocketAddress, Endpoint> TO_ENDPOINT =
235          new Function<InetSocketAddress, Endpoint>() {
236            @Override public Endpoint apply(InetSocketAddress address) {
237              return toEndpoint(address);
238            }
239          };
240    
241      private static Endpoint toEndpoint(InetSocketAddress address) {
242        return new Endpoint(address.getHostName(), address.getPort());
243      }
244    
245      private static class ServiceInstanceFetchException extends RuntimeException {
246        ServiceInstanceFetchException(String message, Throwable cause) {
247          super(message, cause);
248        }
249      }
250    
251      private static class ServiceInstanceDeletedException extends RuntimeException {
252        ServiceInstanceDeletedException(String path) {
253          super(path);
254        }
255      }
256    
257      private static final Function<ServiceInstance, Endpoint> GET_PRIMARY_ENDPOINT =
258          new Function<ServiceInstance, Endpoint>() {
259            @Override public Endpoint apply(ServiceInstance serviceInstance) {
260              return serviceInstance.getServiceEndpoint();
261            }
262          };
263    
264      private class ServerSetWatcher {
265        private final ZooKeeperClient zkClient;
266        private final HostChangeMonitor<ServiceInstance> monitor;
267        @Nullable private ImmutableSet<ServiceInstance> serverSet;
268    
269        ServerSetWatcher(ZooKeeperClient zkClient, HostChangeMonitor<ServiceInstance> monitor) {
270          this.zkClient = zkClient;
271          this.monitor = monitor;
272        }
273    
274        public void watch() throws WatchException, InterruptedException {
275          zkClient.registerExpirationHandler(new Command() {
276            @Override public void execute() {
277              // Servers may have changed Status while we were disconnected from ZooKeeper, check and
278              // re-register our node watches.
279              rebuildServerSet();
280            }
281          });
282    
283          group.watch(new GroupChangeListener() {
284            @Override public void onGroupChange(Iterable<String> memberIds) {
285              notifyGroupChange(memberIds);
286            }
287          });
288        }
289    
290        private Watcher serviceInstanceWatcher = new Watcher() {
291          @Override public void process(WatchedEvent event) {
292            if (event.getState() == KeeperState.SyncConnected) {
293              switch (event.getType()) {
294                case None:
295                  // Ignore re-connects that happen while we're watching
296                  break;
297                case NodeDeleted:
298                  // Ignore deletes since these trigger a group change through the group node watch.
299                  break;
300                case NodeDataChanged:
301                  notifyNodeChange(event.getPath());
302                  break;
303                case NodeCreated:
304                  // This watcher is only applied to ephemeral sequential server set member nodes we
305                  // already know the path of (ie: the ephemeral sequential exists and we're told about
306                  // this by reading children).  Its not clear how we can get a NodeCreated event for a
307                  // node we already know about - but this appears to occur in the wild.  Firing a
308                  // change here is safe even if the event path does not represent a server set member.
309                  // The node de-serializer will throw ServiceInstanceFetchException in this case and
310                  // these exceptions are logged and filtered out of member sets.
311                  notifyNodeChange(event.getPath());
312    
313                  // TODO(John Sirois): inject a Statsprovider and track these events in a stat
314                  LOG.warning("Unexpected NodeCreated event while watching service node: " +
315                      event.getPath());
316    
317                  break;
318                default:
319                  LOG.severe("Unexpected event watching service node: " + event);
320              }
321            }
322          }
323        };
324    
325        private ServiceInstance getServiceInstance(final String nodePath) {
326          try {
327            return backoffHelper.doUntilResult(new Supplier<ServiceInstance>() {
328              @Override public ServiceInstance get() {
329                try {
330                  byte[] data = zkClient.get().getData(nodePath, serviceInstanceWatcher, null);
331                  return codec.deserialize(new ByteArrayInputStream(data));
332                } catch (InterruptedException e) {
333                  Thread.currentThread().interrupt();
334                  throw new ServiceInstanceFetchException(
335                      "Interrupted updating service data for: " + nodePath, e);
336                } catch (ZooKeeperConnectionException e) {
337                  LOG.log(Level.WARNING,
338                      "Temporary error trying to updating service data for: " + nodePath, e);
339                  return null;
340                } catch (NoNodeException e) {
341                  invalidateNodePath(nodePath);
342                  throw new ServiceInstanceDeletedException(nodePath);
343                } catch (KeeperException e) {
344                  if (zkClient.shouldRetry(e)) {
345                    LOG.log(Level.WARNING,
346                        "Temporary error trying to update service data for: " + nodePath, e);
347                    return null;
348                  } else {
349                    throw new ServiceInstanceFetchException(
350                        "Failed to update service data for: " + nodePath, e);
351                  }
352                } catch (IOException e) {
353                  throw new ServiceInstanceFetchException(
354                      "Failed to deserialize the ServiceInstance data for: " + nodePath, e);
355                }
356              }
357            });
358          } catch (InterruptedException e) {
359            Thread.currentThread().interrupt();
360            throw new ServiceInstanceFetchException(
361                "Interrupted trying to update service data for: " + nodePath, e);
362          }
363        }
364    
365        private final LoadingCache<String, ServiceInstance> servicesByMemberId =
366            CacheBuilder.newBuilder().build(new CacheLoader<String, ServiceInstance>() {
367              @Override public ServiceInstance load(String memberId) {
368                return getServiceInstance(group.getMemberPath(memberId));
369              }
370            });
371    
372        private void rebuildServerSet() {
373          Set<String> memberIds = ImmutableSet.copyOf(servicesByMemberId.asMap().keySet());
374          servicesByMemberId.invalidateAll();
375          notifyGroupChange(memberIds);
376        }
377    
378        private void notifyNodeChange(String changedPath) {
379          // Invalidate the associated ServiceInstance to trigger a fetch on group notify.
380          String memberId = invalidateNodePath(changedPath);
381          notifyGroupChange(
382              Iterables.concat(servicesByMemberId.asMap().keySet(), ImmutableList.of(memberId)));
383        }
384    
385        private String invalidateNodePath(String deletedPath) {
386          String memberId = group.getMemberId(deletedPath);
387          servicesByMemberId.invalidate(memberId);
388          return memberId;
389        }
390    
391        private final Function<String, ServiceInstance> MAYBE_FETCH_NODE =
392            new Function<String, ServiceInstance>() {
393              @Override public ServiceInstance apply(String memberId) {
394                // This get will trigger a fetch
395                try {
396                  return servicesByMemberId.getUnchecked(memberId);
397                } catch (UncheckedExecutionException e) {
398                  Throwable cause = e.getCause();
399                  if (!(cause instanceof ServiceInstanceDeletedException)) {
400                    Throwables.propagateIfInstanceOf(cause, ServiceInstanceFetchException.class);
401                    throw new IllegalStateException(
402                        "Unexpected error fetching member data for: " + memberId, e);
403                  }
404                  return null;
405                }
406              }
407            };
408    
409        private synchronized void notifyGroupChange(Iterable<String> memberIds) {
410          ImmutableSet<String> newMemberIds = ImmutableSortedSet.copyOf(memberIds);
411          Set<String> existingMemberIds = servicesByMemberId.asMap().keySet();
412    
413          // Ignore no-op state changes except for the 1st when we've seen no group yet.
414          if ((serverSet == null) || !newMemberIds.equals(existingMemberIds)) {
415            SetView<String> deletedMemberIds = Sets.difference(existingMemberIds, newMemberIds);
416            // Implicit removal from servicesByMemberId.
417            existingMemberIds.removeAll(ImmutableSet.copyOf(deletedMemberIds));
418    
419            Iterable<ServiceInstance> serviceInstances = Iterables.filter(
420                Iterables.transform(newMemberIds, MAYBE_FETCH_NODE), Predicates.notNull());
421    
422            notifyServerSetChange(ImmutableSet.copyOf(serviceInstances));
423          }
424        }
425    
426        private void notifyServerSetChange(ImmutableSet<ServiceInstance> currentServerSet) {
427          // ZK nodes may have changed if there was a session expiry for a server in the server set, but
428          // if the server's status has not changed, we can skip any onChange updates.
429          if (!currentServerSet.equals(serverSet)) {
430            if (currentServerSet.isEmpty()) {
431              LOG.warning("server set empty!");
432            } else {
433              if (LOG.isLoggable(Level.INFO)) {
434                if (serverSet == null) {
435                  LOG.info("received initial membership " + currentServerSet);
436                } else {
437                  logChange(Level.INFO, currentServerSet);
438                }
439              }
440            }
441            serverSet = currentServerSet;
442            monitor.onChange(serverSet);
443          }
444        }
445    
446        private void logChange(Level level, ImmutableSet<ServiceInstance> newServerSet) {
447          StringBuilder message = new StringBuilder("server set change: ");
448          if (serverSet.size() != newServerSet.size()) {
449            message.append("from ").append(serverSet.size())
450                .append(" members to ").append(newServerSet.size());
451          }
452    
453          MapDifference<Endpoint, ServiceInstance> changes =
454              Maps.difference(
455                  Maps.uniqueIndex(serverSet, GET_PRIMARY_ENDPOINT),
456                  Maps.uniqueIndex(newServerSet, GET_PRIMARY_ENDPOINT));
457          Joiner joiner = Joiner.on("\n\t\t");
458          Map<Endpoint, ServiceInstance> left = changes.entriesOnlyOnLeft();
459          if (!left.isEmpty()) {
460            message.append("\n\tleft:\n\t\t").append(joiner.join(left.values()));
461          }
462          Map<Endpoint, ServiceInstance> joined = changes.entriesOnlyOnRight();
463          if (!joined.isEmpty()) {
464            message.append("\n\tjoined:\n\t\t").append(joiner.join(joined.values()));
465          }
466          Map<Endpoint, ValueDifference<ServiceInstance>> differing = changes.entriesDiffering();
467          if (!differing.isEmpty()) {
468            message.append("\n\tstatus changed:\n\t\t").append(joiner.join(differing.values()));
469          }
470          LOG.log(level, message.toString());
471        }
472      }
473    
474      private static Codec<ServiceInstance> createCodec(final boolean useJsonEncoding) {
475        final Codec<ServiceInstance> json = JsonCodec.create(ServiceInstance.class, new GsonBuilder()
476            .setExclusionStrategies(JsonCodec.getThriftExclusionStrategy()).create());
477        final Codec<ServiceInstance> thrift = ThriftCodec.create(ServiceInstance.class,
478            ThriftCodec.BINARY_PROTOCOL);
479        final Predicate<byte[]> recognizer = new Predicate<byte[]>() {
480          public boolean apply(byte[] input) {
481            return (input.length > 1 && input[0] == '{' && input[1] == '\"') == useJsonEncoding;
482          }
483        };
484    
485        if (useJsonEncoding) {
486          return CompatibilityCodec.create(json, thrift, 2, recognizer);
487        }
488        return CompatibilityCodec.create(thrift, json, 2, recognizer);
489      }
490    
491      /**
492       * Creates a codec for {@link ServiceInstance} objects that uses Thrift binary encoding, and can
493       * decode both Thrift and JSON encodings.
494       *
495       * @return a new codec instance.
496       */
497      public static Codec<ServiceInstance> createThriftCodec() {
498        return createCodec(false);
499      }
500    
501      /**
502       * Creates a codec for {@link ServiceInstance} objects that uses JSON encoding, and can decode
503       * both Thrift and JSON encodings.
504       *
505       * @return a new codec instance.
506       */
507      public static Codec<ServiceInstance> createJsonCodec() {
508        return createCodec(true);
509      }
510    
511      /**
512       * Returns a codec for {@link ServiceInstance} objects that uses either the Thrift or the JSON
513       * encoding, depending on whether the command line argument <tt>serverset_json_encofing</tt> is
514       * set to <tt>true</tt>, and can decode both Thrift and JSON encodings.
515       *
516       * @return a new codec instance.
517       */
518      public static Codec<ServiceInstance> createDefaultCodec() {
519        return createCodec(ENCODE_JSON.get());
520      }
521    }