--- memcache-old.rb	2007-12-21 15:54:13.000000000 +1300
+++ memcache.rb	2007-11-07 23:20:35.000000000 +1300
@@ -180,19 +180,16 @@
   end
 
   ##
-  # Deceremets the value for +key+ by +amount+ and returns the new value.
+  # Decrements the value for +key+ by +amount+ and returns the new value.
   # +key+ must already exist.  If +key+ is not an integer, it is assumed to be
   # 0.  +key+ can not be decremented below 0.
 
   def decr(key, amount = 1)
+    raise MemCacheError, "Update of readonly cache" if @readonly
     server, cache_key = request_setup key
 
-    if @multithread then
-      threadsafe_cache_decr server, cache_key, amount
-    else
-      cache_decr server, cache_key, amount
-    end
-  rescue TypeError, SocketError, SystemCallError, IOError => err
+    cache_decr server, cache_key, amount
+  rescue TypeError => err
     handle_error server, err
   end
 
@@ -203,18 +200,14 @@
   def get(key, raw = false)
     server, cache_key = request_setup key
 
-    value = if @multithread then
-              threadsafe_cache_get server, cache_key
-            else
-              cache_get server, cache_key
-            end
+    value = cache_get server, cache_key
 
     return nil if value.nil?
 
     value = Marshal.load value unless raw
 
     return value
-  rescue TypeError, SocketError, SystemCallError, IOError => err
+  rescue TypeError => err
     handle_error server, err
   end
 
@@ -253,38 +246,31 @@
 
     server_keys.each do |server, keys|
       keys = keys.join ' '
-      values = if @multithread then
-                 threadsafe_cache_get_multi server, keys
-               else
-                 cache_get_multi server, keys
-               end
+      values = cache_get_multi server, keys
       values.each do |key, value|
         results[cache_keys[key]] = Marshal.load value
       end
     end
 
     return results
-  rescue TypeError, SocketError, SystemCallError, IOError => err
+  rescue TypeError => err
     handle_error server, err
   end
 
   ##
-  # Increments the value for +key+ by +amount+ and retruns the new value.
+  # Increments the value for +key+ by +amount+ and returns the new value.
   # +key+ must already exist.  If +key+ is not an integer, it is assumed to be
   # 0.
 
   def incr(key, amount = 1)
+    raise MemCacheError, "Update of readonly cache" if @readonly
     server, cache_key = request_setup key
 
-    if @multithread then
-      threadsafe_cache_incr server, cache_key, amount
-    else
-      cache_incr server, cache_key, amount
-    end
-  rescue TypeError, SocketError, SystemCallError, IOError => err
+    cache_incr server, cache_key, amount
+  rescue TypeError => err
     handle_error server, err
   end
-
+  
   ##
   # Add +key+ to the cache with value +value+ that expires in +expiry+
   # seconds.  If +raw+ is true, +value+ will not be Marshalled.
@@ -295,21 +281,14 @@
   def set(key, value, expiry = 0, raw = false)
     raise MemCacheError, "Update of readonly cache" if @readonly
     server, cache_key = request_setup key
-    socket = server.socket
 
     value = Marshal.dump value unless raw
     command = "set #{cache_key} 0 #{expiry} #{value.size}\r\n#{value}\r\n"
 
-    begin
-      @mutex.lock if @multithread
+    with_socket_management(server) do |socket|
       socket.write command
       result = socket.gets
       raise MemCacheError, $1.strip if result =~ /^SERVER_ERROR (.*)/
-    rescue SocketError, SystemCallError, IOError => err
-      server.close
-      raise MemCacheError, err.message
-    ensure
-      @mutex.unlock if @multithread
     end
   end
 
@@ -324,20 +303,13 @@
   def add(key, value, expiry = 0, raw = false)
     raise MemCacheError, "Update of readonly cache" if @readonly
     server, cache_key = request_setup key
-    socket = server.socket
 
     value = Marshal.dump value unless raw
     command = "add #{cache_key} 0 #{expiry} #{value.size}\r\n#{value}\r\n"
 
-    begin
-      @mutex.lock if @multithread
+    with_socket_management(server) do |socket|
       socket.write command
       socket.gets
-    rescue SocketError, SystemCallError, IOError => err
-      server.close
-      raise MemCacheError, err.message
-    ensure
-      @mutex.unlock if @multithread
     end
   end
 
@@ -345,24 +317,13 @@
   # Removes +key+ from the cache in +expiry+ seconds.
 
   def delete(key, expiry = 0)
-    @mutex.lock if @multithread
-
-    raise MemCacheError, "No active servers" unless active?
-    cache_key = make_cache_key key
-    server = get_server_for_key cache_key
-
-    sock = server.socket
-    raise MemCacheError, "No connection to server" if sock.nil?
+    raise MemCacheError, "Update of readonly cache" if @readonly
+    server, cache_key = request_setup key
 
-    begin
-      sock.write "delete #{cache_key} #{expiry}\r\n"
-      sock.gets
-    rescue SocketError, SystemCallError, IOError => err
-      server.close
-      raise MemCacheError, err.message
+    with_socket_management(server) do |socket|
+      socket.write "delete #{cache_key} #{expiry}\r\n"
+      socket.gets
     end
-  ensure
-    @mutex.unlock if @multithread
   end
 
   ##
@@ -374,15 +335,10 @@
     begin
       @mutex.lock if @multithread
       @servers.each do |server|
-        begin
-          sock = server.socket
-          raise MemCacheError, "No connection to server" if sock.nil?
-          sock.write "flush_all\r\n"
-          result = sock.gets
+        with_socket_management(server) do |socket|
+          socket.write "flush_all\r\n"
+          result = socket.gets
           raise MemCacheError, $2.strip if result =~ /^(SERVER_)?ERROR(.*)/
-        rescue SocketError, SystemCallError, IOError => err
-          server.close
-          raise MemCacheError, err.message
         end
       end
     ensure
@@ -436,14 +392,11 @@
     server_stats = {}
 
     @servers.each do |server|
-      sock = server.socket
-      raise MemCacheError, "No connection to server" if sock.nil?
-
-      value = nil
-      begin
-        sock.write "stats\r\n"
+      with_socket_management(server) do |socket|
+        value = nil # TODO: why is this line here?
+        socket.write "stats\r\n"
         stats = {}
-        while line = sock.gets do
+        while line = socket.gets do
           break if line == "END\r\n"
           if line =~ /^STAT ([\w]+) ([\w\.\:]+)/ then
             name, value = $1, $2
@@ -464,9 +417,6 @@
           end
         end
         server_stats["#{server.host}:#{server.port}"] = stats
-      rescue SocketError, SystemCallError, IOError => err
-        server.close
-        raise MemCacheError, err.message
       end
     end
 
@@ -534,11 +484,12 @@
   # found.
 
   def cache_decr(server, cache_key, amount)
-    socket = server.socket
-    socket.write "decr #{cache_key} #{amount}\r\n"
-    text = socket.gets
-    return nil if text == "NOT_FOUND\r\n"
-    return text.to_i
+    with_socket_management(server) do |socket|
+      socket.write "decr #{cache_key} #{amount}\r\n"
+      text = socket.gets
+      return nil if text == "NOT_FOUND\r\n"
+      return text.to_i
+    end
   end
 
   ##
@@ -546,50 +497,52 @@
   # miss.
 
   def cache_get(server, cache_key)
-    socket = server.socket
-    socket.write "get #{cache_key}\r\n"
-    keyline = socket.gets # "VALUE <key> <flags> <bytes>\r\n"
+    with_socket_management(server) do |socket|
+      socket.write "get #{cache_key}\r\n"
+      keyline = socket.gets # "VALUE <key> <flags> <bytes>\r\n"
 
-    if keyline.nil? then
-      server.close
-      raise MemCacheError, "lost connection to #{server.host}:#{server.port}"
-    end
+      if keyline.nil? then
+        server.close
+        raise MemCacheError, "lost connection to #{server.host}:#{server.port}" # TODO: retry here too
+      end
 
-    return nil if keyline == "END\r\n"
+      return nil if keyline == "END\r\n"
 
-    unless keyline =~ /(\d+)\r/ then
-      server.close
-      raise MemCacheError, "unexpected response #{keyline.inspect}"
+      unless keyline =~ /(\d+)\r/ then
+        server.close
+        raise MemCacheError, "unexpected response #{keyline.inspect}"
+      end
+      value = socket.read $1.to_i
+      socket.read 2 # "\r\n"
+      socket.gets   # "END\r\n"
+      return value
     end
-    value = socket.read $1.to_i
-    socket.read 2 # "\r\n"
-    socket.gets   # "END\r\n"
-    return value
   end
 
   ##
   # Fetches +cache_keys+ from +server+ using a multi-get.
 
   def cache_get_multi(server, cache_keys)
-    values = {}
-    socket = server.socket
-    socket.write "get #{cache_keys}\r\n"
+    with_socket_management(server) do |socket|
+      values = {}
+      socket.write "get #{cache_keys}\r\n"
 
-    while keyline = socket.gets do
-      return values if keyline == "END\r\n"
+      while keyline = socket.gets do
+        return values if keyline == "END\r\n"
 
-      unless keyline =~ /^VALUE (.+) (.+) (.+)/ then
-        server.close
-        raise MemCacheError, "unexpected response #{keyline.inspect}"
+        unless keyline =~ /^VALUE (.+) (.+) (.+)/ then
+          server.close
+          raise MemCacheError, "unexpected response #{keyline.inspect}"
+        end
+
+        key, data_length = $1, $3
+        values[$1] = socket.read data_length.to_i
+        socket.read(2) # "\r\n"
       end
 
-      key, data_length = $1, $3
-      values[$1] = socket.read data_length.to_i
-      socket.read(2) # "\r\n"
+      server.close
+      raise MemCacheError, "lost connection to #{server.host}:#{server.port}" # TODO: retry here too
     end
-
-    server.close
-    raise MemCacheError, "lost connection to #{server.host}:#{server.port}"
   end
 
   ##
@@ -597,17 +550,47 @@
   # found.
 
   def cache_incr(server, cache_key, amount)
-    socket = server.socket
-    socket.write "incr #{cache_key} #{amount}\r\n"
-    text = socket.gets
-    return nil if text == "NOT_FOUND\r\n"
-    return text.to_i
+    with_socket_management(server) do |socket|
+      socket.write "incr #{cache_key} #{amount}\r\n"
+      text = socket.gets
+      return nil if text == "NOT_FOUND\r\n"
+      return text.to_i
+    end
+  end
+  
+  ##
+  # Gets or creates a socket connected to the given server, and yields it
+  # to the block.  If a socket error (SocketError, SystemCallError, IOError) 
+  # or protocol error (MemCacheError) is raised by the block, closes the
+  # socket, attempts to connect again, and retries the block (once).  If
+  # an error is again raised, reraises it as MemCacheError.
+  # If unable to connect to the server (or if in the reconnect wait period),
+  # raises MemCacheError - note that the socket connect code marks a server
+  # dead for a timeout period, so retrying does not apply to connection attempt
+  # failures (but does still apply to unexpectedly lost connections etc.).  
+  # Wraps the whole lot in mutex synchronization if @multithread is true.
+
+  def with_socket_management(server, &block)
+    @mutex.lock if @multithread
+    retried = false
+    begin
+      socket = server.socket
+      raise MemCacheError, "No connection to server (#{server.status})" if socket.nil?
+      block.call(socket)
+    rescue MemCacheError, SocketError, SystemCallError, IOError => err
+      handle_error(server, err) if retried || socket.nil?
+      retried = true
+      retry
+    end
+  ensure
+    @mutex.unlock if @multithread
   end
 
   ##
   # Handles +error+ from +server+.
 
   def handle_error(server, error)
+    raise error if error.is_a?(MemCacheError)
     server.close if server
     new_error = MemCacheError.new error.message
     new_error.set_backtrace error.backtrace
@@ -622,38 +605,9 @@
     raise MemCacheError, 'No active servers' unless active?
     cache_key = make_cache_key key
     server = get_server_for_key cache_key
-    raise MemCacheError, 'No connection to server' if server.socket.nil?
     return server, cache_key
   end
 
-  def threadsafe_cache_decr(server, cache_key, amount) # :nodoc:
-    @mutex.lock
-    cache_decr server, cache_key, amount
-  ensure
-    @mutex.unlock
-  end
-
-  def threadsafe_cache_get(server, cache_key) # :nodoc:
-    @mutex.lock
-    cache_get server, cache_key
-  ensure
-    @mutex.unlock
-  end
-
-  def threadsafe_cache_get_multi(socket, cache_keys) # :nodoc:
-    @mutex.lock
-    cache_get_multi socket, cache_keys
-  ensure
-    @mutex.unlock
-  end
-
-  def threadsafe_cache_incr(server, cache_key, amount) # :nodoc:
-    @mutex.lock
-    cache_incr server, cache_key, amount
-  ensure
-    @mutex.unlock
-  end
-
   ##
   # This class represents a memcached server instance.
 
