用TensorFlow进行逻辑回归(四)

更高维的数据

import numpy as np

from sklearn.datasets import *

import pandas as pd

import tensorflow as tf

%matplotlib inline

import matplotlib

import matplotlib.pyplot as plt

from random import *

from sklearn.metrics import accuracy_score

from scipy.special import logit

Xinput, yinput = make_blobs(n_samples = 1000, centers = 2, n_features = 2)

df = pd.DataFrame(dict(x=Xinput[:,0], y = Xinput[:,1], label = yinput))

Xinput[0:5,]

yinput[0:10,]

colors = {0:'red', 1:'blue', 2:'green'}

fig, ax = plt.subplots()

plt.scatter(df[yinput==1]['x'], df[yinput==1]['y'], color = 'red')

plt.scatter(df[yinput==0]['x'], df[yinput==0]['y'], color = 'blue')

plt.tick_params(labelsize=14)

plt.xlabel('x', fontsize = 16)

plt.ylabel('y', fontsize = 16)

plt.show()

#tf.reset_default_graph()

X = tf.cast( Xinput.reshape(-1,2),tf.float32)

Y = tf.cast( yinput.reshape(-1,1),tf.float32)

def run_logistic_model(learning_r, training_epochs, train_obs, train_labels, debug = False):

optimizer = tf.optimizers.SGD(learning_r)

#optimizer= tf.keras.optimizers.Adam(learning_rate=learning_r)

cost_history = np.empty(shape=[0], dtype = float)

loss_list_train = []

loss_list_valid = []

acc_list_train = []

acc_valid_train = []

W = tf.Variable(tf.random.normal((2, 1)))

b = tf.Variable(tf.random.normal((1,)))

for epoch in range(training_epochs+1):

with tf.GradientTape() as tape:

PRED=1/(1+tf.exp(-tf.matmul(X,W)+b))

Loss=-tf.reduce_mean(Y*tf.math.log(PRED)+(1-Y)*tf.math.log(1-PRED))

gradients=tape.gradient(Loss,[W,b])

#_, summary, loss = sess.run([train_op, merged, l], feed_dict=feed_dict)

#y_logit = tf.matmul(X, W) + b

the sigmoid gives the class probability of 1

#y_one_prob = tf.sigmoid(y_logit)

Rounding P(y=1) will give the correct prediction.

#y_pred = tf.round(y_one_prob)

#entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_logit, labels=Y)

Sum all contributions

#l = tf.reduce_sum(entropy)

#gradients=tape.gradient(l,[W,b])

optimizer.apply_gradients(zip(gradients, [W, b]))

#Acc=tf.reduce_mean(tf.cast(tf.equal(tf.cast(tf.where(PRED.numpy()<0.5,0,1),tf.float32),Y),tf.float32))

#loss_train =loss(train_obs,train_labels,W,b).numpy()

#loss_valid =loss(valid_x,valid_y,W,B).numpy()

#acc_train=accuracy(train_obs,train_labels,W,b).numpy()

#acc_valid=accuracy(valid_x,valid_y,W,B).numpy()

loss_list_train.append(Loss)

#loss_list_valid.append(loss_valid)

#acc_list_train.append(Acc)

#acc_valid_train.append(acc_valid)

#print("epoch={:3d},train_loss={:.4f},train_acc={:.4f},val_loss={:.4f},val_acc={:.4f}".format(epoch+1,loss_train,acc_train,loss_valid,acc_valid))

if (epoch % 1000 == 0) & debug:

print("Reached epoch",epoch,"cost J =", str.format('{0:.6f}', Loss))

return loss_list_train,W,b

cost_history,W,b = run_logistic_model(learning_r = 0.01,

training_epochs = 10000,

train_obs =X,

train_labels = Y,

debug = True)

plt.rc('font', family='arial')

plt.rc('xtick', labelsize='x-small')

plt.rc('ytick', labelsize='x-small')

plt.tight_layout()

fig = plt.figure(figsize=(10,8))

ax = fig.add_subplot(1, 1, 1)

ax.plot(cost_history, ls='solid', color = 'black', label = '\\gamma = 0.01')

ax.set_xlabel('epochs', fontsize = 16)

ax.set_ylabel('Cost function J', fontsize = 16)

plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize = 16)

plt.tick_params(labelsize=16)

colors = {0:'red', 1:'blue', 2:'green'}

fig, ax = plt.subplots()

plt.scatter(df[yinput==1]['x'], df[yinput==1]['y'], color = 'red')

plt.scatter(df[yinput==0]['x'], df[yinput==0]['y'], color = 'blue')

plt.tick_params(labelsize=14)

plt.xlabel('x', fontsize = 16)

plt.ylabel('y', fontsize = 16)

plt.xlim(-10, 7.5)

plt.ylim(-10, 20)

x_left = -10

y_left = (1./W[1]) * (-b + logit(.5) - W[0]*x_left)

x_right = 7.5

y_right = (1./W[1]) * (-b + logit(.5) - W[0]*x_right)

plt.plot([x_left, x_right], [y_left, y_right], color='k')

plt.show()

#plt.savefig("logistic_pred.png")

相关推荐
ar01235 小时前
AR远程协助作用
人工智能·ar
北京青翼科技5 小时前
PCIe接口-高速模拟采集—高性能计算卡-青翼科技高品质军工级数据采集板-打造专业工业核心板
图像处理·人工智能·fpga开发·信号处理·智能硬件
软件聚导航5 小时前
马年、我用AI写了个“打工了马” 小程序
人工智能·ui·微信小程序
陈天伟教授6 小时前
人工智能应用-机器听觉:7. 统计合成法
人工智能·语音识别
笨蛋不要掉眼泪6 小时前
Spring Boot集成LangChain4j:与大模型对话的极速入门
java·人工智能·后端·spring·langchain
昨夜见军贴06166 小时前
IACheck AI审核技术赋能消费认证:为智能宠物喂食器TELEC报告构筑智能合规防线
人工智能·宠物
DisonTangor7 小时前
阿里开源语音识别模型——Qwen3-ASR
人工智能·开源·语音识别
万事ONES7 小时前
ONES 签约北京高级别自动驾驶示范区专设国有运营平台——北京车网
人工智能·机器学习·自动驾驶
qyr67897 小时前
深度解析:3D细胞培养透明化试剂供应链与主要制造商分布
大数据·人工智能·3d·市场分析·市场报告·3d细胞培养·细胞培养
软件开发技术深度爱好者7 小时前
浅谈人工智能(AI)对个人发展的影响
人工智能