Trouble with a simple cross recognizer

After working through the simple SGD demo, I wanted a better grasp of backprop, so I decided to
implement a simple network that recognizes an ‘X’ sign in a 3x3 array.

That is, if it is fed the ‘X’ below as input, it should output 1; for every other input it should output 0.

```
tensor([[1., 0., 1.],
        [0., 1., 0.],
        [1., 0., 1.]])
```

I am finding that the network does not converge to the expected outcome.

I was expecting the learned weights to look like the ones below, but the code does not converge to them.
I'd appreciate any help or pointers on what needs tweaking. Thanks!

```
tensor([[ 1., -1.,  1.],
        [-1.,  1., -1.],
        [ 1., -1.,  1.]])
```
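As a sanity check on that expectation, here is a small sketch that scores the cross with those hand-picked weights. It assumes PyTorch and a bias weight of 0 as the tenth element (my assumption, since the pattern above has only nine entries). The sigmoid output is about 0.993 rather than 1, and scaling the same weights up moves it closer to 1, so minimizing MSE against targets of exactly 0 and 1 tends to keep growing the weights instead of settling at ±1.

```python
import torch

# The 'X' pattern, flattened to 9 values plus a bias input of 1
x = torch.tensor([[1., 0., 1.],
                  [0., 1., 0.],
                  [1., 0., 1.]])
x_with_bias = torch.cat((x.reshape(9), torch.ones(1)))

# Hand-picked weights: +1 on the cross cells, -1 elsewhere, bias weight 0 (assumed)
w = torch.cat((2 * x.reshape(9) - 1, torch.zeros(1)))

score = x_with_bias @ w                               # 5 cells * (+1) = 5.0
out = torch.sigmoid(score)                            # sigmoid(5) is about 0.9933
out_scaled = torch.sigmoid(x_with_bias @ (3 * w))     # sigmoid(15), closer to 1

print(score.item(), out.item(), out_scaled.item())
```

With only 10 random training samples, many different weight vectors separate the crosses from the noise, so the learned weights need not resemble this pattern even when training works.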

Code below:

```python
import torch
from torch import nn

# 'X' represented as 3x3
x = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.float)
# Count of training samples
sample_count = 10
ratio_of_crosses_in_input = 0.33

# Create sample_count random samples of 3x3
input_ = torch.rand(sample_count, 3, 3)
# Create a bias element to append to the end of the image data
bias = torch.ones(sample_count, 1)
# Reshape each sample from 3x3 to a flat row of 9 values
input_1x = input_.reshape(sample_count, 9)

# Add the bias as the last element, i.e. the 10th element
input_with_bias = torch.cat((input_1x, bias), 1)
# Turn roughly the last ratio_of_crosses_in_input of the samples into an 'X'
input_with_bias[int(sample_count * (1 - ratio_of_crosses_in_input)):, :] = torch.cat((x.reshape(9), torch.ones(1)), 0)

# Expected result: 1 for every sample that is an 'X', 0 otherwise
expected_y = torch.zeros(sample_count, 1)
expected_y[int(sample_count * (1 - ratio_of_crosses_in_input)):] = 1

# Single-layer network: 9 pixel weights plus 1 bias weight, initialized in [-0.5, 0.5)
layer1 = torch.rand(10) - 0.5

# Mean squared error
def mse(y_hat, y): return ((y_hat - y) ** 2).mean()

lr = 0.05
sig = torch.nn.Sigmoid()
layer1 = nn.Parameter(layer1)

def update(i):
    y_hat = input_with_bias @ layer1
    y_hat = sig(y_hat)
    loss = mse(y_hat, expected_y)
    loss.backward()
    if i % 5000 == 0:
        print("Loss:", loss.item())

    # Plain gradient-descent step, then reset the gradient for the next iteration
    with torch.no_grad():
        layer1.sub_(lr * layer1.grad)
        layer1.grad.zero_()

# Update the weights
for i in range(50000):
    update(i)
```
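One thing worth checking in the code above: `input_with_bias @ layer1` returns a tensor of shape `(sample_count,)`, while `expected_y` has shape `(sample_count, 1)`. Subtracting them broadcasts to a `(sample_count, sample_count)` matrix, so `mse` compares every prediction against every target instead of pairing each prediction with its own label. The sketch below assumes PyTorch; the names mirror the post, while the seed, the `losses` list and the final evaluation are my own additions. It shows the broadcast and then trains with matching shapes.

```python
import torch

torch.manual_seed(0)  # assumed, only for reproducibility

sample_count = 10
ratio_of_crosses_in_input = 0.33
x = torch.tensor([[1., 0., 1.], [0., 1., 0.], [1., 0., 1.]])

# Same data layout as the post: 9 pixels plus a bias input of 1
inputs = torch.cat((torch.rand(sample_count, 9), torch.ones(sample_count, 1)), dim=1)
first_cross = int(sample_count * (1 - ratio_of_crosses_in_input))
inputs[first_cross:] = torch.cat((x.reshape(9), torch.ones(1)))
expected_y = torch.zeros(sample_count, 1)
expected_y[first_cross:] = 1

weights = torch.nn.Parameter(torch.rand(10) - 0.5)

# The mismatch: (10,) minus (10, 1) broadcasts to (10, 10)
y_hat = torch.sigmoid(inputs @ weights)
diff_shape = (y_hat - expected_y).shape
print(y_hat.shape, expected_y.shape, diff_shape)
# torch.Size([10]) torch.Size([10, 1]) torch.Size([10, 10])

# Fix: flatten the targets so both sides have shape (sample_count,)
targets = expected_y.squeeze(1)
lr = 0.05
losses = []
for i in range(50000):
    y_hat = torch.sigmoid(inputs @ weights)
    loss = ((y_hat - targets) ** 2).mean()
    losses.append(loss.item())
    loss.backward()
    with torch.no_grad():
        weights -= lr * weights.grad  # gradient-descent step
        weights.grad.zero_()
    if i % 5000 == 0:
        print(f"step {i}: loss {loss.item():.4f}")

# Predictions on the training set: crosses should end up above 0.5, noise below
with torch.no_grad():
    predictions = torch.sigmoid(inputs @ weights)
print(predictions)
```

Other ways to make the shapes agree would be `sig(input_with_bias @ layer1).unsqueeze(1)` or creating `layer1` with shape `(10, 1)`; either keeps prediction and target paired element by element.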