问题描述
我需要什么帮助?
调试或更正锁算法的建议。
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.atomicreference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
public class TestHCLHLock {
static final int MAX_CLUSTERS = 8;
public static class ThreadID {
private static volatile int nextID = 0;
private static ThreadLocalID threadID = new ThreadLocalID();
public static int get() { return threadID.get();}
public static void reset() { nextID = 0; }
public static void set(int value) { threadID.set(value);}
public static int getCluster(int n) {return threadID.get() % n;}
private static class ThreadLocalID extends ThreadLocal<Integer> {
protected synchronized Integer initialValue() { return nextID++; }}}
static class QNode {
private static final int TWS_MASK = 0x80000000;
private static final int SMW_MASK = 0x40000000;
private static final int CLUSTER_MASK = 0x3FFFFFFF;
AtomicInteger state;
public QNode() { state = new AtomicInteger(0); }
boolean waitForGrantOrClusterMaster(int myCluster) {
while (true) {
if (getClusterID() == myCluster && !isTailWhenSpliced() && !isSuccessorMustWait()) {
return true;
} else if (getClusterID() != myCluster || isTailWhenSpliced()) {
return false;
} } }
public void unlock() {
int oldState = 0;
int newState = ThreadID.getCluster(MAX_CLUSTERS) & CLUSTER_MASK;
newState |= SMW_MASK;
newState &= (~TWS_MASK);
do {
oldState = state.get();
} while (!state.compareAndSet(oldState,newState)); }
public int getClusterID() { return state.get() & CLUSTER_MASK; }
public void setClusterID(int clusterID) {
int oldState,newState;
do {
oldState = state.get();
newState = (oldState & ~CLUSTER_MASK) | clusterID;
} while (!state.compareAndSet(oldState,newState)); }
public boolean isSuccessorMustWait() { return (state.get() & SMW_MASK) != 0; }
public void setSuccessorMustWait(boolean successorMustWait) {
int oldState,newState;
do {
oldState = state.get();
if (successorMustWait) {
newState = oldState | SMW_MASK;
} else {
newState = oldState & ~SMW_MASK;
}
} while (!state.compareAndSet(oldState,newState)); }
public boolean isTailWhenSpliced() { return (state.get() & TWS_MASK) != 0; }
public void setTailWhenSpliced(boolean tailWhenSpliced) {
int oldState,newState;
do {
oldState = state.get();
if (tailWhenSpliced) {
newState = oldState | TWS_MASK;
} else {
newState = oldState & ~TWS_MASK;
}
} while (!state.compareAndSet(oldState,newState)); }}
public class HCLHLock implements Lock {
List<atomicreference<QNode>> localQueues;
atomicreference<QNode> globalQueue;
ThreadLocal<QNode> currNode = new ThreadLocal<QNode>() {protected QNode initialValue() { return new QNode(); };};
ThreadLocal<QNode> preNode = new ThreadLocal<QNode>() {protected QNode initialValue() { return null; };};
public HCLHLock() {
localQueues = new ArrayList<atomicreference<QNode>>(MAX_CLUSTERS);
for (int i = 0; i < MAX_CLUSTERS; i++) {
localQueues.add(new atomicreference<QNode>());
}
QNode head = new QNode();
globalQueue = new atomicreference<QNode>(head); }
public void lock() {
QNode myLocalNode = currNode.get();
int myCluster = ThreadID.getCluster(MAX_CLUSTERS);
myLocalNode.setClusterID(myCluster);
atomicreference<QNode> localQueue = localQueues.get(ThreadID.getCluster(MAX_CLUSTERS));
QNode myLocalPred = null;
QNode myGlobalPred = null;
QNode localTail = null;
do {
myLocalPred = localQueue.get();
} while (!localQueue.compareAndSet(myLocalPred,myLocalNode));
if (myLocalPred != null) {
boolean iOwnLock = myLocalPred.waitForGrantOrClusterMaster(myCluster);
preNode.set(myLocalPred);
if (iOwnLock) { return; }
}
do {
myGlobalPred = globalQueue.get();
localTail = localQueue.get();
} while (!globalQueue.compareAndSet(myGlobalPred,localTail));
localTail.setTailWhenSpliced(true);
while (myGlobalPred.isSuccessorMustWait()) {
}
preNode.set(myGlobalPred); }
public void unlock() {
QNode myNode = currNode.get();
myNode.setSuccessorMustWait(false);
QNode myPred= preNode.get();
if (myPred != null){
myPred.unlock();
currNode.set(myPred); } }
public void lockInterruptibly() throws InterruptedException { throw new UnsupportedOperationException(); }
public boolean tryLock() { throw new UnsupportedOperationException(); }
public boolean tryLock(long time,TimeUnit unit) throws InterruptedException { throw new UnsupportedOperationException(); }
public Condition newCondition() { throw new UnsupportedOperationException(); }
}
private final static int THREADS = 32;
private final static int COUNT = 32 * 64;
private final static int PER_THREAD = COUNT / THREADS;
int counter = 0;
HCLHLock instance = new HCLHLock();
public class MyThread extends Thread {
public void run() {
for (int i = 0; i < PER_THREAD; i++) {
instance.lock();
try {
counter = counter + 1;
} finally {
instance.unlock();
}}}}
public void test () throws Exception{
Thread[] thread = new Thread[THREADS];
for (int i = 0; i < THREADS; i++) {
thread[i] = new MyThread();
}
for (int i = 0; i < THREADS; i++) {
thread[i].start();
}
for (int i = 0; i < THREADS; i++) {
thread[i].join();
}
System.out.println(String.format("expect %d,but %d",2048,counter));
}
public static void main(String[] args) throws Exception{
TestHCLHLock test = new TestHCLHLock();
test.test(); }
}
如果您有兴趣,可以查看完整版本的 https://github.com/jiamo/HCLHlock。一种奇怪的行为是:在函数 isSuccessorMustWait
中添加 printf 可以使测试轻松通过。(https://github.com/jiamo/HCLHlock/blob/main/src/main/java/HCLHLock.java#L177)
------------- 将测试代码合并为一个类 ----------
expect 2048,but 2047
主函数总是像上面那样输出。所以锁不能锁定关键代码。但是输出计数几乎接近2048
,这意味着大多数时候算法可以成功锁定它。
------------- 一个更强大的版本,但仍然是错误的 -----------
如果有自己的锁。再等等。
- preNode.set(myLocalPred);
- if (iOwnLock) { return; }
- }
+ if (iOwnLock) {
+ while (myLocalPred.isSuccessorMustWait()) { }
+ preNode.set(myLocalPred);
+ return;
+ } }
------------- 由@mevets 测试 -----------
这个算法不小心出现了死锁。
解决方法
我得到了一些略有不同的结果:
for i in 1 2 3 4 5 6 7 8 9 10; do java TestHCLHLock; done
expect 2048,but 2044
expect 2048,but 2042
expect 2048,but 2016
expect 2048,but 2018
expect 2048,but 2004
expect 2048,but 2043
expect 2048,but 2024
_ 注意它没有完成第 8 次运行。还在疯狂的打转。 因此,我在忙等待循环中放置了一些计数器,并打印了一条消息:
state = c0000001
state = c0000005
state = c0000005
state = c0000007
state = c0000005
state = c0000000
state = c0000004
...
这是在第 170 行循环太多次时打印 state() 的值。 我希望这有帮助。我不是 Java 人,但我认为有调试器之类的东西。
这是在 20 个内核 + 20 个线程的 linux 机器上,我相信是 2x(10+10)。